linearize works

This commit is contained in:
George Hotz 2025-07-31 19:37:38 -07:00
commit aee4ebe52f
4 changed files with 45 additions and 11 deletions

View file

@ -38,7 +38,10 @@ class TestTiny(unittest.TestCase):
b = Tensor.eye(N).contiguous().realize()
c = Tensor.eye(N).contiguous().realize()
d = Tensor.eye(N).contiguous().realize()
out = (((a@b).relu()@c).relu()@d).contiguous().realize()
e = Tensor.eye(N).contiguous().realize()
f = Tensor.eye(N).contiguous().realize()
g = Tensor.eye(N).contiguous().realize()
out = (a@b@c@d@e@f@g).contiguous().realize()
self.assertListEqual(out.flatten().tolist(), [1.0]*(N*N))
def test_conv2d(self):
@ -50,10 +53,12 @@ class TestTiny(unittest.TestCase):
def test_double_conv2d(self):
N = 64
a = Tensor.ones(1,4,N,N).contiguous().realize()
w1 = Tensor.ones(16,4,3,3).contiguous().realize()
w2 = Tensor.ones(24,16,3,3).contiguous().realize()
w3 = Tensor.ones(32,24,3,3).contiguous().realize()
out = a.conv2d(w1).conv2d(w2).conv2d(w3).contiguous().realize()
w1 = Tensor.ones(4,4,3,3).contiguous().realize()
w2 = Tensor.ones(4,4,3,3).contiguous().realize()
w3 = Tensor.ones(4,4,3,3).contiguous().realize()
w4 = Tensor.ones(4,4,3,3).contiguous().realize()
w5 = Tensor.ones(4,4,3,3).contiguous().realize()
out = a.conv2d(w1).conv2d(w2).conv2d(w3).conv2d(w4).conv2d(w5).contiguous().realize()
# *** randomness ***

View file

@ -90,9 +90,9 @@ class BlockContext:
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
# RANGE/IF add to the next ctx
# STORE/ASSIGN subtract from the next ctx
# STORE/REDUCE_AXIS subtract from the next ctx
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
elif u.op in {Ops.STORE, Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
return ctx
# ***** make blocks *****

View file

@ -141,7 +141,9 @@ class CStyleLanguage(Renderer):
c: defaultdict[str, int] = defaultdict(int)
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.NOOP:
if len(u.src): r[u] = r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

View file

@ -422,11 +422,14 @@ def contiguous_create_ranges(ctx:list[int], x:UOp):
ranges = [UOp.range(dtypes.int, s, ctx[0]+i) if resolve(s!=1) else UOp.const(dtypes.int, 0) for i,s in enumerate(x.shape)]
ctx[0] += len(ranges)
mm = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
return x.replace(src=(mm,)+tuple(ranges))
buf = UOp.new_buffer(x.device, prod(x.shape), x.dtype)
mm2 = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(buf,)+tuple(ranges))
return UOp(Ops.STORE, src=(mm2, mm)+tuple(ranges))
#return x.replace(src=(mm,)+tuple(ranges))
def map_reshape(x:UOp):
# don't push on the final buffer reshape for readable graph
if x.src[0].src[0].op is Ops.BUFFER: return None
#if x.src[0].src[0].op is Ops.BUFFER: return None
acc = 1
to_sum = []
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
@ -441,6 +444,7 @@ def map_reshape(x:UOp):
else:
ret.append(UOp.const(dtypes.int, 0))
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
#return UOp(Ops.MAP, dtype=x.src[0].src[0].dtype, src=(x.src[0].src[0],)+ret)
# generic
return x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+ret) for s in x.src[0].src]))
@ -464,6 +468,7 @@ def map_permute(x:UOp):
print(perm, x.src[0].arg)
ret = tuple([ret[p] for p in perm])
print(x.src[0].src[0].shape, ret)
#return UOp(Ops.MAP, dtype=x.src[0].src[0].dtype, src=(x.src[0].src[0],)+ret)
return x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+ret) for s in x.src[0].src]))
def map_expand(x:UOp):
@ -493,6 +498,7 @@ def map_shrink(ctx:list[int], x:UOp):
ctx[0] += 1
ret[i] = UOp(Ops.CATRANGE, src=tuple(new_ret_i))
mm = UOp(Ops.MAP, r.dtype, src=(r.src[0],)+tuple(ret))
#return mm
# TODO: put the ranges on the shrink?
return UOp(Ops.SHRINK, r.dtype, src=(mm,), arg=r.arg)
@ -508,6 +514,14 @@ index_pushing = PatternMatcher([
lambda x: x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+x.src[1:]) for s in x.src[0].src]))),
])
fix_buffers = PatternMatcher([
(UPat(Ops.BUFFER, name="x"), lambda x: UOp(Ops.DEFINE_GLOBAL, dtype=x.dtype.ptr(x.arg), arg=x.src[0].arg)),
(UPat(Ops.MAP, name="x"), lambda x: x.replace(op=Ops.INDEX, dtype=x.src[0].dtype)),
(UPat((Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE), name="x"), lambda x: x.src[0]),
(UPat(Ops.EXPAND, name="x"), lambda x: x.replace(arg=None, op=Ops.NOOP)),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: x.replace(op=Ops.REDUCE, arg=x.arg[0])),
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
@ -519,7 +533,20 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
Returns:
Map transforming each UOp in the sink to the Ops.KERNEL graph.
"""
graph_rewrite(sink, index_pushing, ctx=[0], name="index pushing", bottom_up=True)
pushed = graph_rewrite(sink, index_pushing, ctx=[0], name="index pushing", bottom_up=True)
pushed = graph_rewrite(pushed, fix_buffers, name="fix buffers")
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
pushed = graph_rewrite(pushed, pm_reduce, ctx=ReduceContext(), name="remove reduce")
from tinygrad.codegen.linearize import block_create, BlockContext, pm_blockend_merge, block_merge, pm_finalize
pushed = graph_rewrite(pushed, block_create, ctx=BlockContext.from_sink(pushed), name="block create", bottom_up=True)
pushed = graph_rewrite(pushed, pm_blockend_merge, name="blockend merge")
pushed = graph_rewrite(pushed, block_merge, name="block merge")
pushed = graph_rewrite(pushed, pm_finalize, name="finalize")
from tinygrad.device import Device
try:
print(Device['CPU'].renderer.render(pushed.arg.lst))
except Exception as e:
print("render fail", e)
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")