mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
8 commits
master
...
mega_lower
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
954f30a1c1 | ||
|
|
58e79f3fc8 | ||
|
|
aee4ebe52f | ||
|
|
cb9bcc60de |
||
|
|
92160f1cf1 | ||
|
|
2584fe8907 | ||
|
|
8dff2c1375 | ||
|
|
3cff1a6b13 |
5 changed files with 166 additions and 4 deletions
|
|
@ -33,6 +33,35 @@ class TestTiny(unittest.TestCase):
|
||||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||||
|
|
||||||
|
def test_double_gemm(self, N=64, BS=1):
  """Multiply a ones tensor by six identity matrices, with a contiguous between the two groups of three.

  Every factor is an identity matrix, so the result must still be all ones.
  """
  lhs = Tensor.ones(BS,N,N).contiguous().realize()
  # six separately realized identity matrices: the first three feed the first chain, the rest the second
  eyes = [Tensor.eye(N).contiguous().realize() for _ in range(6)]
  first = lhs
  for m in eyes[:3]: first = first @ m
  second = first.contiguous()
  for m in eyes[3:]: second = second @ m
  out = second.contiguous().realize()
  self.assertListEqual(out.flatten().tolist(), [1.0]*(BS*N*N))
|
||||||
|
|
||||||
|
def test_double_gemm_bs(self, N=64, BS=4):
  """Run test_double_gemm with a batch dimension (batch size 4 unless overridden)."""
  # forward the arguments: N and BS used to be accepted but ignored in favour of a hard-coded BS=4
  self.test_double_gemm(N=N, BS=BS)
|
||||||
|
|
||||||
|
def test_conv2d(self):
  """Convolve an all-ones image with all-ones 3x3 weights and check the output shape and values."""
  N = 64
  a = Tensor.ones(1,4,N,N).contiguous().realize()
  w1 = Tensor.ones(16,4,3,3).contiguous().realize()
  out = a.conv2d(w1).contiguous().realize()
  # conv2d is called with its defaults (assumed: stride 1, no padding), so a 3x3 kernel trims 2 from each spatial dim
  self.assertEqual(out.shape, (1,16,N-2,N-2))
  # every output element sums 4 input channels * 3*3 taps of 1.0
  self.assertListEqual(out.flatten().tolist(), [36.0]*(16*(N-2)*(N-2)))
|
||||||
|
|
||||||
|
def test_double_conv2d(self):
  """Chain five 3x3 convolutions (4 -> 4 channels) on an all-ones image and check the output shape."""
  N = 64
  a = Tensor.ones(1,4,N,N).contiguous().realize()
  # five separately realized weight tensors, applied in order
  weights = [Tensor.ones(4,4,3,3).contiguous().realize() for _ in range(5)]
  out = a
  for w in weights: out = out.conv2d(w)
  out = out.contiguous().realize()
  # conv2d is called with its defaults (assumed: stride 1, no padding), so each 3x3 conv trims 2 from each spatial dim
  final = N - 2*len(weights)
  self.assertEqual(out.shape, (1,4,final,final))
|
||||||
|
|
||||||
# *** randomness ***
|
# *** randomness ***
|
||||||
|
|
||||||
def test_random(self):
|
def test_random(self):
|
||||||
|
|
|
||||||
|
|
@ -90,9 +90,9 @@ class BlockContext:
|
||||||
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
|
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
|
||||||
|
|
||||||
# RANGE/IF add to the next ctx
|
# RANGE/IF add to the next ctx
|
||||||
# STORE/ASSIGN subtract from the next ctx
|
# STORE/REDUCE_AXIS subtract from the next ctx
|
||||||
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
|
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
|
||||||
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
|
elif u.op in {Ops.STORE, Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
# ***** make blocks *****
|
# ***** make blocks *****
|
||||||
|
|
|
||||||
|
|
@ -141,7 +141,9 @@ class CStyleLanguage(Renderer):
|
||||||
c: defaultdict[str, int] = defaultdict(int)
|
c: defaultdict[str, int] = defaultdict(int)
|
||||||
name = "test"
|
name = "test"
|
||||||
for u in uops:
|
for u in uops:
|
||||||
if u.op is Ops.NOOP: continue
|
if u.op is Ops.NOOP:
|
||||||
|
if len(u.src): r[u] = r[u.src[0]]
|
||||||
|
continue
|
||||||
if u.op is Ops.SINK:
|
if u.op is Ops.SINK:
|
||||||
if u.arg is not None: name = u.arg.function_name
|
if u.arg is not None: name = u.arg.function_name
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
||||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||||
from tinygrad.uop.symbolic import symbolic_simple
|
from tinygrad.uop.symbolic import symbolic_simple
|
||||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP, argsort
|
||||||
from tinygrad.dtype import ImageDType, dtypes
|
from tinygrad.dtype import ImageDType, dtypes
|
||||||
from tinygrad.schedule.multi import multi_pm
|
from tinygrad.schedule.multi import multi_pm
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
|
|
@ -417,6 +417,121 @@ finalize_contiguous = PatternMatcher([
|
||||||
|
|
||||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||||
|
|
||||||
|
def contiguous_create_ranges(ctx:list[int], x:UOp):
  """Rewrite a single-source CONTIGUOUS into a STORE of a MAP-indexed source into a MAP-indexed new buffer.

  One index is created per dimension of x.shape: a fresh RANGE (numbered from the counter in ctx[0],
  which is advanced) when the dimension does not resolve to 1, otherwise the constant 0.
  Returns None (no rewrite) when x does not have exactly one source.
  """
  if len(x.src) != 1: return None
  (inp,) = x.src
  idxs: list[UOp] = []
  for dim in x.shape:
    if not resolve(dim != 1):
      # nothing to iterate over on this axis
      idxs.append(UOp.const(dtypes.int, 0))
      continue
    idxs.append(UOp.range(dtypes.int, dim, ctx[0]))
    ctx[0] += 1
  index_srcs = tuple(idxs)
  load_side = UOp(Ops.MAP, dtype=inp.dtype, src=(inp, *index_srcs))
  # the destination is a brand new buffer holding prod(shape) elements, viewed with x's shape
  out_buf = UOp.new_buffer(x.device, prod(x.shape), x.dtype).reshape(x.shape)
  store_side = UOp(Ops.MAP, dtype=inp.dtype, src=(out_buf, *index_srcs))
  return UOp(Ops.STORE, src=(store_side, load_side, *index_srcs))
|
||||||
|
|
||||||
|
def map_reshape(x:UOp):
  """Push a MAP through the RESHAPE in x.src[0] by re-deriving the indices for the RESHAPE's input shape.

  The MAP's indices are folded into one linear index using x.shape as row-major extents, and that
  linear index is then split again (mod / floordiv) along the shape of the RESHAPE's first source.
  The returned UOp is the RESHAPE with each of its sources wrapped in a MAP carrying the new indices.
  """
  reshape = x.src[0]
  # linearize: walk from the innermost axis outwards, accumulating the stride
  stride, terms = 1, []
  for dim, idx in reversed(list(zip(x.shape, x.src[1:]))):
    terms.append(stride*idx)
    stride *= dim
  flat = sum(terms)
  # de-linearize against the input shape, again innermost axis first
  unflat = []
  for dim in reversed(reshape.src[0].shape):
    if resolve(dim != 1):
      unflat.append(flat % dim)
      flat //= dim
    else:
      unflat.append(UOp.const(dtypes.int, 0))
  # simplify all index expressions together through a SINK, then restore outermost-first order
  new_idxs = UOp.sink(*unflat).simplify().src[::-1] if unflat else ()
  return reshape.replace(src=tuple(UOp(Ops.MAP, dtype=s.dtype, src=(s,)+new_idxs) for s in reshape.src))
|
||||||
|
|
||||||
|
def map_reduce(ctx:list[int], x:UOp):
  """Push a MAP through the REDUCE_AXIS in x.src[0], creating one new RANGE per reduced axis.

  For every axis listed in the reduce's arg[1], the incoming index (asserted to be a CONST) is
  replaced with a fresh RANGE over the input dimension, numbered from ctx[0] (which is advanced).
  Returns a REDUCE_AXIS over the re-indexed input whose extra sources are the newly created ranges.
  """
  red = x.src[0]
  idxs = list(x.src[1:])
  reduce_rngs = []
  for axis, dim in enumerate(red.src[0].shape):
    if axis not in red.arg[1]: continue
    assert idxs[axis].op == Ops.CONST
    fresh = UOp.range(dtypes.int, dim, ctx[0])
    ctx[0] += 1
    idxs[axis] = fresh
    reduce_rngs.append(fresh)
  inner = UOp(Ops.MAP, red.src[0].dtype, src=(red.src[0], *idxs))
  return UOp(Ops.REDUCE_AXIS, red.dtype, src=(inner, *reduce_rngs), arg=red.arg)
|
||||||
|
|
||||||
|
def map_permute(x:UOp):
  """Push a MAP through the PERMUTE in x.src[0] by reordering the MAP's indices.

  Returns the PERMUTE with each of its sources wrapped in a MAP that carries the reordered indices.
  """
  permute = x.src[0]
  out_idxs = x.src[1:]
  # assuming arg[i] names the input axis that becomes output axis i (as in Tensor.permute), input axis j
  # is indexed by the output index at position argsort(arg)[j], so the inverse permutation is the right one
  inverse = argsort(permute.arg)
  in_idxs = tuple(out_idxs[p] for p in inverse)
  return permute.replace(src=tuple(UOp(Ops.MAP, dtype=s.dtype, src=(s,)+in_idxs) for s in permute.src))
|
||||||
|
|
||||||
|
def map_expand(x:UOp):
  """Push a MAP through the EXPAND in x.src[0].

  On every axis where the EXPAND's input and output dimensions differ, the incoming index is moved
  onto the returned EXPAND's sources and the input is indexed with constant 0 instead.
  """
  expand = x.src[0]
  idxs = list(x.src[1:])
  broadcast_idxs = []
  for axis, (in_dim, out_dim) in enumerate(zip(expand.src[0].shape, expand.shape)):
    if in_dim != out_dim:
      broadcast_idxs.append(idxs[axis])
      idxs[axis] = UOp.const(dtypes.int, 0)
  inner = UOp(Ops.MAP, expand.dtype, src=(expand.src[0], *idxs))
  return UOp(Ops.EXPAND, expand.dtype, src=(inner, *broadcast_idxs), arg=expand.arg)
|
||||||
|
|
||||||
|
def map_shrink(ctx:list[int], x:UOp):
  """Push a MAP through the SHRINK in x.src[0], offsetting each index by the shrink's start.

  The SHRINK's arg is read as one (start, end) pair per axis. Output element i of a shrunk axis is
  input element start+i, so the start is added to that axis' index. This replaces the previous
  `assert ss == 0` and its permanently disabled CATRANGE branch.

  ctx is unused here; it stays in the signature so the function is called the same way as before.
  Returns a SHRINK (same arg) wrapping the re-indexed input.
  """
  shrink = x.src[0]
  idxs = list(x.src[1:])
  for axis, (_dim, (start, _end)) in enumerate(zip(shrink.src[0].shape, shrink.arg)):
    # a literal zero start needs no offset; anything else (including a symbolic start) is added
    if isinstance(start, int) and start == 0: continue
    idxs[axis] = idxs[axis] + start
  inner = UOp(Ops.MAP, shrink.dtype, src=(shrink.src[0], *idxs))
  # TODO: put the ranges on the shrink?
  return UOp(Ops.SHRINK, shrink.dtype, src=(inner,), arg=shrink.arg)
|
||||||
|
|
||||||
|
def _map_through_elementwise(x:UOp):
  """Replace MAP(elementwise_op, *idxs) with elementwise_op applied to MAP(src, *idxs) for each of its sources."""
  inner = x.src[0]
  return inner.replace(src=tuple(UOp(Ops.MAP, dtype=s.dtype, src=(s,)+x.src[1:]) for s in inner.src))

# rewrite rules: CONTIGUOUS becomes an indexed STORE, then each MAP is pushed through the op underneath it
index_pushing = PatternMatcher([
  (UPat(Ops.CONTIGUOUS, name="x"), contiguous_create_ranges),
  # MAP over a movement or reduce op: recompute the indices for that op's input
  (UPat(Ops.MAP, src=(UPat(Ops.RESHAPE),), allow_any_len=True, name="x"), map_reshape),
  (UPat(Ops.MAP, src=(UPat(Ops.PERMUTE),), allow_any_len=True, name="x"), map_permute),
  (UPat(Ops.MAP, src=(UPat(Ops.EXPAND),), allow_any_len=True, name="x"), map_expand),
  (UPat(Ops.MAP, src=(UPat(Ops.SHRINK),), allow_any_len=True, name="x"), map_shrink),
  (UPat(Ops.MAP, src=(UPat(Ops.REDUCE_AXIS),), allow_any_len=True, name="x"), map_reduce),
  # move MAP through elementwise ALU
  (UPat(Ops.MAP, src=(UPat(GroupOp.Elementwise),), allow_any_len=True, name="x"), _map_through_elementwise),
  # MAP on STORE is NOOP
  (UPat(Ops.MAP, src=(UPat(Ops.STORE),), allow_any_len=True, name="x"), lambda x: x.src[0]),
])
|
||||||
|
|
||||||
|
# rewrite rules applied after index pushing: swap the remaining tensor-level ops for their indexed equivalents
fix_buffers = PatternMatcher([
  # BUFFER -> DEFINE_GLOBAL pointer of the buffer's size, numbered by the arg of the buffer's first source
  (UPat(Ops.BUFFER, name="x"), lambda x: UOp(Ops.DEFINE_GLOBAL, dtype=x.dtype.ptr(x.arg), arg=x.src[0].arg)),
  # MAP -> LOAD of an INDEX with the same sources
  (UPat(Ops.MAP, name="x"), lambda x: x.replace(op=Ops.INDEX, dtype=x.src[0].dtype).load()),
  # a STORE whose target became a LOAD stores to that LOAD's INDEX instead
  (UPat(Ops.STORE, src=(UPat(Ops.LOAD),), name="x", allow_any_len=True),
   lambda x: x.replace(src=(x.src[0].src[0],)+x.src[1:])),
  # movement ops are dropped once their indexing has been pushed into the MAPs
  # open question carried over for EXPAND: do EXPANDs need to track the ranges they end?
  (UPat((Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE, Ops.EXPAND), name="x"), lambda x: x.src[0]),
  # REDUCE_AXIS -> REDUCE, keeping only the first element of the arg
  (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: x.replace(op=Ops.REDUCE, arg=x.arg[0])),
])
|
||||||
|
|
||||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
|
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -428,6 +543,20 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||||
Returns:
|
Returns:
|
||||||
Map transforming each UOp in the sink to the Ops.KERNEL graph.
|
Map transforming each UOp in the sink to the Ops.KERNEL graph.
|
||||||
"""
|
"""
|
||||||
|
pushed = graph_rewrite(sink, index_pushing, ctx=[0], name="index pushing", bottom_up=True)
|
||||||
|
pushed = graph_rewrite(pushed, fix_buffers, name="fix buffers")
|
||||||
|
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
|
||||||
|
pushed = graph_rewrite(pushed, pm_reduce, ctx=ReduceContext(), name="remove reduce")
|
||||||
|
from tinygrad.codegen.linearize import block_create, BlockContext, pm_blockend_merge, block_merge, pm_finalize
|
||||||
|
pushed = graph_rewrite(pushed, block_create, ctx=BlockContext.from_sink(pushed), name="block create", bottom_up=True)
|
||||||
|
pushed = graph_rewrite(pushed, pm_blockend_merge, name="blockend merge")
|
||||||
|
pushed = graph_rewrite(pushed, block_merge, name="block merge")
|
||||||
|
pushed = graph_rewrite(pushed, pm_finalize, name="finalize")
|
||||||
|
from tinygrad.device import Device
|
||||||
|
try:
|
||||||
|
print(Device['CPU'].renderer.render(pushed.arg.lst))
|
||||||
|
except Exception as e:
|
||||||
|
print("render fail", e)
|
||||||
|
|
||||||
# multi + merge_views + simplify
|
# multi + merge_views + simplify
|
||||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ class FastEnum(IntEnum):
|
||||||
class Ops(FastEnum):
|
class Ops(FastEnum):
|
||||||
# uops that aren't rendered
|
# uops that aren't rendered
|
||||||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
|
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
|
||||||
|
MAP = auto(); CATRANGE = auto()
|
||||||
|
|
||||||
# buffer ops
|
# buffer ops
|
||||||
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||||
|
|
@ -82,6 +83,7 @@ class GroupOp:
|
||||||
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
|
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
|
||||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||||
ALU = set.union(Unary, Binary, Ternary)
|
ALU = set.union(Unary, Binary, Ternary)
|
||||||
|
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
|
||||||
|
|
||||||
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue