mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
cleanups
This commit is contained in:
parent
04fa825a26
commit
2feeb8c8a6
2 changed files with 10 additions and 87 deletions
|
|
@@ -91,12 +91,12 @@ class TestTiny(unittest.TestCase):
|
|||
@unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images")
|
||||
def test_mnist(self):
|
||||
layers = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm(64), Tensor.max_pool2d,
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu, Tensor.contiguous,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu, Tensor.contiguous,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d, Tensor.contiguous,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu, Tensor.contiguous,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu, Tensor.contiguous,
|
||||
nn.BatchNorm(64), Tensor.max_pool2d, Tensor.contiguous,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
# replace random weights with ones
|
||||
|
|
|
|||
|
|
@@ -93,7 +93,8 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
|
|||
|
||||
def extract_children(ctx:RangeifyContext, x:UOp):
|
||||
if ctx.children is not None: return
|
||||
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1 and k.op not in {Ops.DEVICE, Ops.CONST}}
|
||||
# REDUCE_AXIS is fine here, should go to contig only (gate)
|
||||
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1 and any(x.op is Ops.REDUCE_AXIS for x in k.toposort())}
|
||||
def mark_children(ctx:RangeifyContext, x:UOp):
|
||||
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
|
||||
return x.replace(src=tuple(new_srcs))
|
||||
|
|
@@ -109,81 +110,8 @@ rangeify_fixups = PatternMatcher([
|
|||
# const
|
||||
(UPat(Ops.CONST, name="x"), lambda x:
|
||||
x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if len(x.src) and x.src[0].op is Ops.VIEW else None),
|
||||
|
||||
# add contiguous to EXPAND
|
||||
#(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].contiguous().expand(x.arg).replace(tag=1) if x.tag is None else None),
|
||||
])
|
||||
|
||||
"""
|
||||
def record_child(ctx:RangeifyContext, c:UOp, idx:UOp):
|
||||
if c not in ctx.indexed_child: ctx.indexed_child[c] = []
|
||||
print("record child", id(idx), len(ctx.indexed_child[c]))
|
||||
if idx not in ctx.indexed_child[c]:
|
||||
ctx.indexed_child[c].append(idx)
|
||||
return idx.replace(src=(c.replace(tag=1),)+idx.src[1:])
|
||||
if len(ctx.indexed_child[c]) == c.arg:
|
||||
assert idx in ctx.indexed_child[c]
|
||||
ec = UOp(Ops.ENDCHILD, dtype=c.dtype, src=(c,))
|
||||
if ec not in ctx.ended_child: ctx.ended_child[ec] = 1
|
||||
else: ctx.ended_child[ec] += 1
|
||||
print("creating endchild", ctx.ended_child[ec])
|
||||
if c in ctx.seen_child:
|
||||
out_rngs = list(idx.src[1:])
|
||||
idx_ranges, end_ranges = ctx.seen_child[c]
|
||||
if len(idx_ranges) == 0:
|
||||
return c.src[0].index(*out_rngs)
|
||||
for i,nr in zip(idx_ranges, end_ranges):
|
||||
out_rngs[i] = nr
|
||||
return ec.index(*out_rngs).contiguous(*end_ranges, arg=ec.shape, tag=1).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
else:
|
||||
# heres where we can compute everything about mismatched ranges
|
||||
all_rngs = zip(*[x.src[1:] for x in ctx.indexed_child[c]])
|
||||
out_rngs = []
|
||||
end_ranges = []
|
||||
idx_ranges = []
|
||||
for i,r in enumerate(all_rngs):
|
||||
if all_same(r):
|
||||
out_rngs.append(r[0])
|
||||
else:
|
||||
out_rngs.append(UOp.range(dtypes.int, c.shape[i], (ctx.idx, AxisType.LOOP)))
|
||||
ctx.idx += 1
|
||||
end_ranges.append(out_rngs[-1])
|
||||
idx_ranges.append(i)
|
||||
if len(end_ranges) == 0:
|
||||
# safe to remove child right away
|
||||
#return c.src[0].index(*out_rngs)
|
||||
return ec.index(*out_rngs)
|
||||
else:
|
||||
ctx.seen_child[c] = (idx_ranges, end_ranges)
|
||||
return ec.index(*out_rngs).contiguous(*end_ranges, arg=ec.shape, tag=1).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
|
||||
def child_check(ctx:RangeifyContext, sink:UOp):
|
||||
subs = {}
|
||||
for x in ctx.ended_child:
|
||||
if ctx.ended_child[x] == x.src[0].arg:
|
||||
print("sub")
|
||||
subs[x] = x.src[0].src[0]
|
||||
ctx.ended_child.clear()
|
||||
return sink.substitute(subs)
|
||||
"""
|
||||
|
||||
"""
|
||||
def visit_child(ctx:RangeifyContext, x:UOp):
|
||||
print(f"visit CHILD {x.arg} bottom up")
|
||||
if x.src[0] not in ctx.seen_children: ctx.seen_children[x.src[0]] = set()
|
||||
ctx.seen_children[x.src[0]].add(x.arg)
|
||||
if len(ctx.seen_children[x.src[0]]) != x.src[0].arg: raise RewriteNotReady
|
||||
print("READY")
|
||||
|
||||
def visit_children(ctx:RangeifyContext, x:UOp):
|
||||
if x.tag == 1: return None
|
||||
if len(ctx.seen_children[x]) != x.arg:
|
||||
print("visit CHILDREN bottom up -- not ready")
|
||||
raise RewriteNotReady
|
||||
print("visit CHILDREN bottom up -- READY")
|
||||
return x.replace(tag=1)
|
||||
"""
|
||||
|
||||
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
|
||||
print(f"visit CHILD {x.arg} bottom up")
|
||||
if c not in ctx.seen_children: ctx.seen_children[c] = {}
|
||||
|
|
@@ -215,17 +143,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
|
|||
pm_rangeify = PatternMatcher([
|
||||
# if there are new ended children, tag the SINK
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
|
||||
#(UPat(Ops.SINK, name="sink"), child_check),
|
||||
#(UPat(Ops.CHILD, name="x"), visit_child),
|
||||
#(UPat(Ops.CHILDREN, name="x"), visit_children),
|
||||
|
||||
|
||||
# if there's an INDEX it can support partial contig
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, name="x"),), allow_any_len=True, name="idx"), map_contiguous),
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), map_contiguous),
|
||||
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
|
||||
|
||||
# this is like the definitions of these
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
|
||||
|
|
@@ -238,9 +160,10 @@ pm_rangeify = PatternMatcher([
|
|||
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
|
||||
|
||||
# move MAP through elementwise ALU
|
||||
# move MAP through elementwise ALU / reduce. these are the items with cost
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE})),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
|
||||
|
||||
# CONST can't have axes. remove srcs when we idx
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue