mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
linearize works
This commit is contained in:
parent
cb9bcc60de
commit
aee4ebe52f
4 changed files with 45 additions and 11 deletions
|
|
@ -38,7 +38,10 @@ class TestTiny(unittest.TestCase):
|
|||
b = Tensor.eye(N).contiguous().realize()
|
||||
c = Tensor.eye(N).contiguous().realize()
|
||||
d = Tensor.eye(N).contiguous().realize()
|
||||
out = (((a@b).relu()@c).relu()@d).contiguous().realize()
|
||||
e = Tensor.eye(N).contiguous().realize()
|
||||
f = Tensor.eye(N).contiguous().realize()
|
||||
g = Tensor.eye(N).contiguous().realize()
|
||||
out = (a@b@c@d@e@f@g).contiguous().realize()
|
||||
self.assertListEqual(out.flatten().tolist(), [1.0]*(N*N))
|
||||
|
||||
def test_conv2d(self):
|
||||
|
|
@ -50,10 +53,12 @@ class TestTiny(unittest.TestCase):
|
|||
def test_double_conv2d(self):
|
||||
N = 64
|
||||
a = Tensor.ones(1,4,N,N).contiguous().realize()
|
||||
w1 = Tensor.ones(16,4,3,3).contiguous().realize()
|
||||
w2 = Tensor.ones(24,16,3,3).contiguous().realize()
|
||||
w3 = Tensor.ones(32,24,3,3).contiguous().realize()
|
||||
out = a.conv2d(w1).conv2d(w2).conv2d(w3).contiguous().realize()
|
||||
w1 = Tensor.ones(4,4,3,3).contiguous().realize()
|
||||
w2 = Tensor.ones(4,4,3,3).contiguous().realize()
|
||||
w3 = Tensor.ones(4,4,3,3).contiguous().realize()
|
||||
w4 = Tensor.ones(4,4,3,3).contiguous().realize()
|
||||
w5 = Tensor.ones(4,4,3,3).contiguous().realize()
|
||||
out = a.conv2d(w1).conv2d(w2).conv2d(w3).conv2d(w4).conv2d(w5).contiguous().realize()
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
|
|
|
|||
|
|
@ -90,9 +90,9 @@ class BlockContext:
|
|||
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
|
||||
|
||||
# RANGE/IF add to the next ctx
|
||||
# STORE/ASSIGN subtract from the next ctx
|
||||
# STORE/REDUCE_AXIS/CONTIGUOUS subtract from the next ctx
|
||||
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
|
||||
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
|
||||
elif u.op in {Ops.STORE, Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
|
||||
return ctx
|
||||
|
||||
# ***** make blocks *****
|
||||
|
|
|
|||
|
|
@ -141,7 +141,9 @@ class CStyleLanguage(Renderer):
|
|||
c: defaultdict[str, int] = defaultdict(int)
|
||||
name = "test"
|
||||
for u in uops:
|
||||
if u.op is Ops.NOOP: continue
|
||||
if u.op is Ops.NOOP:
|
||||
if len(u.src): r[u] = r[u.src[0]]
|
||||
continue
|
||||
if u.op is Ops.SINK:
|
||||
if u.arg is not None: name = u.arg.function_name
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -422,11 +422,14 @@ def contiguous_create_ranges(ctx:list[int], x:UOp):
|
|||
ranges = [UOp.range(dtypes.int, s, ctx[0]+i) if resolve(s!=1) else UOp.const(dtypes.int, 0) for i,s in enumerate(x.shape)]
|
||||
ctx[0] += len(ranges)
|
||||
mm = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
|
||||
return x.replace(src=(mm,)+tuple(ranges))
|
||||
buf = UOp.new_buffer(x.device, prod(x.shape), x.dtype)
|
||||
mm2 = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(buf,)+tuple(ranges))
|
||||
return UOp(Ops.STORE, src=(mm2, mm)+tuple(ranges))
|
||||
#return x.replace(src=(mm,)+tuple(ranges))
|
||||
|
||||
def map_reshape(x:UOp):
|
||||
# don't push on the final buffer reshape for readable graph
|
||||
if x.src[0].src[0].op is Ops.BUFFER: return None
|
||||
#if x.src[0].src[0].op is Ops.BUFFER: return None
|
||||
acc = 1
|
||||
to_sum = []
|
||||
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
|
||||
|
|
@ -441,6 +444,7 @@ def map_reshape(x:UOp):
|
|||
else:
|
||||
ret.append(UOp.const(dtypes.int, 0))
|
||||
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
|
||||
#return UOp(Ops.MAP, dtype=x.src[0].src[0].dtype, src=(x.src[0].src[0],)+ret)
|
||||
# generic
|
||||
return x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+ret) for s in x.src[0].src]))
|
||||
|
||||
|
|
@ -464,6 +468,7 @@ def map_permute(x:UOp):
|
|||
print(perm, x.src[0].arg)
|
||||
ret = tuple([ret[p] for p in perm])
|
||||
print(x.src[0].src[0].shape, ret)
|
||||
#return UOp(Ops.MAP, dtype=x.src[0].src[0].dtype, src=(x.src[0].src[0],)+ret)
|
||||
return x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+ret) for s in x.src[0].src]))
|
||||
|
||||
def map_expand(x:UOp):
|
||||
|
|
@ -493,6 +498,7 @@ def map_shrink(ctx:list[int], x:UOp):
|
|||
ctx[0] += 1
|
||||
ret[i] = UOp(Ops.CATRANGE, src=tuple(new_ret_i))
|
||||
mm = UOp(Ops.MAP, r.dtype, src=(r.src[0],)+tuple(ret))
|
||||
#return mm
|
||||
# TODO: put the ranges on the shrink?
|
||||
return UOp(Ops.SHRINK, r.dtype, src=(mm,), arg=r.arg)
|
||||
|
||||
|
|
@ -508,6 +514,14 @@ index_pushing = PatternMatcher([
|
|||
lambda x: x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+x.src[1:]) for s in x.src[0].src]))),
|
||||
])
|
||||
|
||||
fix_buffers = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="x"), lambda x: UOp(Ops.DEFINE_GLOBAL, dtype=x.dtype.ptr(x.arg), arg=x.src[0].arg)),
|
||||
(UPat(Ops.MAP, name="x"), lambda x: x.replace(op=Ops.INDEX, dtype=x.src[0].dtype)),
|
||||
(UPat((Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE), name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.EXPAND, name="x"), lambda x: x.replace(arg=None, op=Ops.NOOP)),
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: x.replace(op=Ops.REDUCE, arg=x.arg[0])),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
|
|
@ -519,7 +533,20 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
Returns:
|
||||
Map transforming each UOp in the sink to the Ops.KERNEL graph.
|
||||
"""
|
||||
graph_rewrite(sink, index_pushing, ctx=[0], name="index pushing", bottom_up=True)
|
||||
pushed = graph_rewrite(sink, index_pushing, ctx=[0], name="index pushing", bottom_up=True)
|
||||
pushed = graph_rewrite(pushed, fix_buffers, name="fix buffers")
|
||||
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
|
||||
pushed = graph_rewrite(pushed, pm_reduce, ctx=ReduceContext(), name="remove reduce")
|
||||
from tinygrad.codegen.linearize import block_create, BlockContext, pm_blockend_merge, block_merge, pm_finalize
|
||||
pushed = graph_rewrite(pushed, block_create, ctx=BlockContext.from_sink(pushed), name="block create", bottom_up=True)
|
||||
pushed = graph_rewrite(pushed, pm_blockend_merge, name="blockend merge")
|
||||
pushed = graph_rewrite(pushed, block_merge, name="block merge")
|
||||
pushed = graph_rewrite(pushed, pm_finalize, name="finalize")
|
||||
from tinygrad.device import Device
|
||||
try:
|
||||
print(Device['CPU'].renderer.render(pushed.arg.lst))
|
||||
except Exception as e:
|
||||
print("render fail", e)
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue