This commit is contained in:
George Hotz 2025-08-11 08:57:30 -07:00
commit 2feeb8c8a6
2 changed files with 10 additions and 87 deletions

View file

@ -91,12 +91,12 @@ class TestTiny(unittest.TestCase):
@unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images")
def test_mnist(self):
layers = [
nn.Conv2d(1, 32, 5), Tensor.relu,
nn.Conv2d(32, 32, 5), Tensor.relu,
nn.BatchNorm(32), Tensor.max_pool2d,
nn.Conv2d(32, 64, 3), Tensor.relu,
nn.Conv2d(64, 64, 3), Tensor.relu,
nn.BatchNorm(64), Tensor.max_pool2d,
nn.Conv2d(1, 32, 5), Tensor.relu, Tensor.contiguous,
nn.Conv2d(32, 32, 5), Tensor.relu, Tensor.contiguous,
nn.BatchNorm(32), Tensor.max_pool2d, Tensor.contiguous,
nn.Conv2d(32, 64, 3), Tensor.relu, Tensor.contiguous,
nn.Conv2d(64, 64, 3), Tensor.relu, Tensor.contiguous,
nn.BatchNorm(64), Tensor.max_pool2d, Tensor.contiguous,
lambda x: x.flatten(1), nn.Linear(576, 10)]
# replace random weights with ones

View file

@ -93,7 +93,8 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
def extract_children(ctx:RangeifyContext, x:UOp):
if ctx.children is not None: return
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1 and k.op not in {Ops.DEVICE, Ops.CONST}}
# REDUCE_AXIS is fine here, should go to contig only (gate)
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1 and any(x.op is Ops.REDUCE_AXIS for x in k.toposort())}
def mark_children(ctx:RangeifyContext, x:UOp):
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
return x.replace(src=tuple(new_srcs))
@ -109,81 +110,8 @@ rangeify_fixups = PatternMatcher([
# const
(UPat(Ops.CONST, name="x"), lambda x:
x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if len(x.src) and x.src[0].op is Ops.VIEW else None),
# add contiguous to EXPAND
#(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].contiguous().expand(x.arg).replace(tag=1) if x.tag is None else None),
])
"""
def record_child(ctx:RangeifyContext, c:UOp, idx:UOp):
if c not in ctx.indexed_child: ctx.indexed_child[c] = []
print("record child", id(idx), len(ctx.indexed_child[c]))
if idx not in ctx.indexed_child[c]:
ctx.indexed_child[c].append(idx)
return idx.replace(src=(c.replace(tag=1),)+idx.src[1:])
if len(ctx.indexed_child[c]) == c.arg:
assert idx in ctx.indexed_child[c]
ec = UOp(Ops.ENDCHILD, dtype=c.dtype, src=(c,))
if ec not in ctx.ended_child: ctx.ended_child[ec] = 1
else: ctx.ended_child[ec] += 1
print("creating endchild", ctx.ended_child[ec])
if c in ctx.seen_child:
out_rngs = list(idx.src[1:])
idx_ranges, end_ranges = ctx.seen_child[c]
if len(idx_ranges) == 0:
return c.src[0].index(*out_rngs)
for i,nr in zip(idx_ranges, end_ranges):
out_rngs[i] = nr
return ec.index(*out_rngs).contiguous(*end_ranges, arg=ec.shape, tag=1).index(*[idx.src[1+i] for i in idx_ranges])
else:
# heres where we can compute everything about mismatched ranges
all_rngs = zip(*[x.src[1:] for x in ctx.indexed_child[c]])
out_rngs = []
end_ranges = []
idx_ranges = []
for i,r in enumerate(all_rngs):
if all_same(r):
out_rngs.append(r[0])
else:
out_rngs.append(UOp.range(dtypes.int, c.shape[i], (ctx.idx, AxisType.LOOP)))
ctx.idx += 1
end_ranges.append(out_rngs[-1])
idx_ranges.append(i)
if len(end_ranges) == 0:
# safe to remove child right away
#return c.src[0].index(*out_rngs)
return ec.index(*out_rngs)
else:
ctx.seen_child[c] = (idx_ranges, end_ranges)
return ec.index(*out_rngs).contiguous(*end_ranges, arg=ec.shape, tag=1).index(*[idx.src[1+i] for i in idx_ranges])
def child_check(ctx:RangeifyContext, sink:UOp):
subs = {}
for x in ctx.ended_child:
if ctx.ended_child[x] == x.src[0].arg:
print("sub")
subs[x] = x.src[0].src[0]
ctx.ended_child.clear()
return sink.substitute(subs)
"""
"""
def visit_child(ctx:RangeifyContext, x:UOp):
print(f"visit CHILD {x.arg} bottom up")
if x.src[0] not in ctx.seen_children: ctx.seen_children[x.src[0]] = set()
ctx.seen_children[x.src[0]].add(x.arg)
if len(ctx.seen_children[x.src[0]]) != x.src[0].arg: raise RewriteNotReady
print("READY")
def visit_children(ctx:RangeifyContext, x:UOp):
if x.tag == 1: return None
if len(ctx.seen_children[x]) != x.arg:
print("visit CHILDREN bottom up -- not ready")
raise RewriteNotReady
print("visit CHILDREN bottom up -- READY")
return x.replace(tag=1)
"""
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
print(f"visit CHILD {x.arg} bottom up")
if c not in ctx.seen_children: ctx.seen_children[c] = {}
@ -215,17 +143,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
pm_rangeify = PatternMatcher([
# if there are new ended children, tag the SINK
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
#(UPat(Ops.SINK, name="sink"), child_check),
#(UPat(Ops.CHILD, name="x"), visit_child),
#(UPat(Ops.CHILDREN, name="x"), visit_children),
# if there's an INDEX it can support partial contig
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, name="x"),), allow_any_len=True, name="idx"), map_contiguous),
(UPat(Ops.CONTIGUOUS, name="x"), map_contiguous),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
# this is like the definitions of these
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
@ -238,9 +160,10 @@ pm_rangeify = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
# move MAP through elementwise ALU
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE})),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
# CONST can't have axes. remove srcs when we idx
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),