Compare commits

...

8 commits

Author SHA1 Message Date
George Hotz
954f30a1c1 work 2025-07-31 21:11:28 -07:00
George Hotz
58e79f3fc8 bugfixes 2025-07-31 19:57:15 -07:00
George Hotz
aee4ebe52f linearize works 2025-07-31 19:37:38 -07:00
George Hotz
cb9bcc60de
Merge branch 'master' into mega_lowerer 2025-07-31 18:12:47 -07:00
George Hotz
92160f1cf1 stuff 2025-07-31 13:50:04 -07:00
George Hotz
2584fe8907 naive shrink pushes everything left 2025-07-30 17:12:38 -07:00
George Hotz
8dff2c1375 triple gemm 2025-07-30 17:03:02 -07:00
George Hotz
3cff1a6b13 run the lowerer on the big graph 2025-07-30 15:23:43 -07:00
5 changed files with 166 additions and 4 deletions

View file

@ -33,6 +33,35 @@ class TestTiny(unittest.TestCase):
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
def test_double_gemm(self, N=64, BS=1):
a = Tensor.ones(BS,N,N).contiguous().realize()
b = Tensor.eye(N).contiguous().realize()
c = Tensor.eye(N).contiguous().realize()
d = Tensor.eye(N).contiguous().realize()
e = Tensor.eye(N).contiguous().realize()
f = Tensor.eye(N).contiguous().realize()
g = Tensor.eye(N).contiguous().realize()
out = ((a@b@c@d).contiguous()@e@f@g).contiguous().realize()
self.assertListEqual(out.flatten().tolist(), [1.0]*(BS*N*N))
def test_double_gemm_bs(self, N=64, BS=1): self.test_double_gemm(BS=4)
def test_conv2d(self):
N = 64
a = Tensor.ones(1,4,N,N).contiguous().realize()
w1 = Tensor.ones(16,4,3,3).contiguous().realize()
out = a.conv2d(w1).contiguous().realize()
def test_double_conv2d(self):
N = 64
a = Tensor.ones(1,4,N,N).contiguous().realize()
w1 = Tensor.ones(4,4,3,3).contiguous().realize()
w2 = Tensor.ones(4,4,3,3).contiguous().realize()
w3 = Tensor.ones(4,4,3,3).contiguous().realize()
w4 = Tensor.ones(4,4,3,3).contiguous().realize()
w5 = Tensor.ones(4,4,3,3).contiguous().realize()
out = a.conv2d(w1).conv2d(w2).conv2d(w3).conv2d(w4).conv2d(w5).contiguous().realize()
# *** randomness ***
def test_random(self):

View file

@ -90,9 +90,9 @@ class BlockContext:
ctx.block_ctxs[u] = _sort_ctx(this_block_ctx) if u.op is not Ops.SINK else ()
# RANGE/IF add to the next ctx
# STORE/ASSIGN subtract from the next ctx
# STORE/REDUCE_AXIS subtract from the next ctx
if u.op in {Ops.RANGE, Ops.IF}: ctx.child_ctxs[u] = _sort_ctx(ctx.block_ctxs[u] + (u,))
elif u.op is Ops.STORE: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
elif u.op in {Ops.STORE, Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ctx.child_ctxs[u] = tuple([y for y in ctx.block_ctxs[u] if y not in u.src])
return ctx
# ***** make blocks *****

View file

@ -141,7 +141,9 @@ class CStyleLanguage(Renderer):
c: defaultdict[str, int] = defaultdict(int)
name = "test"
for u in uops:
if u.op is Ops.NOOP: continue
if u.op is Ops.NOOP:
if len(u.src): r[u] = r[u.src[0]]
continue
if u.op is Ops.SINK:
if u.arg is not None: name = u.arg.function_name
continue

View file

@ -3,7 +3,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP, argsort
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.schedule.multi import multi_pm
from tinygrad.shape.shapetracker import ShapeTracker
@ -417,6 +417,121 @@ finalize_contiguous = PatternMatcher([
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def contiguous_create_ranges(ctx:list[int], x:UOp):
if len(x.src) != 1: return None
ranges = []
for s in x.shape:
if resolve(s!=1):
ranges.append(UOp.range(dtypes.int, s, ctx[0]))
ctx[0] += 1
else:
ranges.append(UOp.const(dtypes.int, 0))
mm = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
buf = UOp.new_buffer(x.device, prod(x.shape), x.dtype).reshape(x.shape)
mm2 = UOp(Ops.MAP, dtype=x.src[0].dtype, src=(buf,)+tuple(ranges))
return UOp(Ops.STORE, src=(mm2, mm)+tuple(ranges))
#return x.replace(src=(mm,)+tuple(ranges))
def map_reshape(x:UOp):
# don't push on the final buffer reshape for readable graph
#if x.src[0].src[0].op is Ops.BUFFER: return None
acc = 1
to_sum = []
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
to_sum.append(acc*src)
acc *= s
mish = sum(to_sum)
ret = []
for s in x.src[0].src[0].shape[::-1]:
if resolve(s!=1):
ret.append(mish % s)
mish //= s
else:
ret.append(UOp.const(dtypes.int, 0))
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
#return UOp(Ops.MAP, dtype=x.src[0].src[0].dtype, src=(x.src[0].src[0],)+ret)
# generic
return x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+ret) for s in x.src[0].src]))
def map_reduce(ctx:list[int], x:UOp):
rngs = list(x.src[1:])
r = x.src[0]
new_ranges = []
for i,s in enumerate(r.src[0].shape):
if i in r.arg[1]:
assert rngs[i].op == Ops.CONST
rngs[i] = UOp.range(dtypes.int, s, ctx[0])
new_ranges.append(rngs[i])
ctx[0] += 1
mm = UOp(Ops.MAP, r.src[0].dtype, src=(r.src[0],)+tuple(rngs))
return UOp(Ops.REDUCE_AXIS, r.dtype, src=(mm,)+tuple(new_ranges), arg=r.arg)
def map_permute(x:UOp):
ret = x.src[1:]
# argsort or not?
perm = argsort(x.src[0].arg)
print(perm, x.src[0].arg)
ret = tuple([ret[p] for p in perm])
print(x.src[0].src[0].shape, ret)
#return UOp(Ops.MAP, dtype=x.src[0].src[0].dtype, src=(x.src[0].src[0],)+ret)
return x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+ret) for s in x.src[0].src]))
def map_expand(x:UOp):
r = x.src[0]
inp_shape, exp_shape = x.src[0].src[0].shape, x.src[0].shape
ret = list(x.src[1:])
exp_ranges = []
for i,(x,y) in enumerate(zip(inp_shape, exp_shape)):
if x != y:
exp_ranges.append(ret[i])
ret[i] = UOp.const(dtypes.int, 0)
mm = UOp(Ops.MAP, r.dtype, src=(r.src[0],)+tuple(ret))
return UOp(Ops.EXPAND, r.dtype, src=(mm,)+tuple(exp_ranges), arg=r.arg)
def map_shrink(ctx:list[int], x:UOp):
r = x.src[0]
ret = list(x.src[1:])
for i,(s,(ss,se)) in enumerate(zip(r.src[0].shape, r.arg)):
assert ss == 0, "add to range?"
if se-ss != s and False:
new_ret_i = [ret[i]]
if ss != 0:
new_ret_i = [UOp.range(dtypes.int, ss, ctx[0])] + new_ret_i
ctx[0] += 1
if se != s:
new_ret_i = new_ret_i + [UOp.range(dtypes.int, s-se, ctx[0])]
ctx[0] += 1
ret[i] = UOp(Ops.CATRANGE, src=tuple(new_ret_i))
mm = UOp(Ops.MAP, r.dtype, src=(r.src[0],)+tuple(ret))
#return mm
# TODO: put the ranges on the shrink?
return UOp(Ops.SHRINK, r.dtype, src=(mm,), arg=r.arg)
index_pushing = PatternMatcher([
(UPat(Ops.CONTIGUOUS, name="x"), contiguous_create_ranges),
(UPat(Ops.MAP, src=(UPat(Ops.RESHAPE),), allow_any_len=True, name="x"), map_reshape),
(UPat(Ops.MAP, src=(UPat(Ops.PERMUTE),), allow_any_len=True, name="x"), map_permute),
(UPat(Ops.MAP, src=(UPat(Ops.EXPAND),), allow_any_len=True, name="x"), map_expand),
(UPat(Ops.MAP, src=(UPat(Ops.SHRINK),), allow_any_len=True, name="x"), map_shrink),
(UPat(Ops.MAP, src=(UPat(Ops.REDUCE_AXIS),), allow_any_len=True, name="x"), map_reduce),
# move MAP through elementwise ALU
(UPat(Ops.MAP, src=(UPat(GroupOp.Elementwise),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([UOp(Ops.MAP, dtype=s.dtype, src=(s,)+x.src[1:]) for s in x.src[0].src]))),
# MAP on STORE is NOOP
(UPat(Ops.MAP, src=(UPat(Ops.STORE),), allow_any_len=True, name="x"), lambda x: x.src[0]),
])
fix_buffers = PatternMatcher([
(UPat(Ops.BUFFER, name="x"), lambda x: UOp(Ops.DEFINE_GLOBAL, dtype=x.dtype.ptr(x.arg), arg=x.src[0].arg)),
(UPat(Ops.MAP, name="x"), lambda x: x.replace(op=Ops.INDEX, dtype=x.src[0].dtype).load()),
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), name="x", allow_any_len=True), lambda x: x.replace(src=(x.src[0].src[0],)+x.src[1:])),
(UPat((Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE), name="x"), lambda x: x.src[0]),
# do EXPANDs need to track the ranges they end?
(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0]),
#(UPat(Ops.EXPAND, name="x"), lambda x: x.replace(arg=None, op=Ops.NOOP)),
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: x.replace(op=Ops.REDUCE, arg=x.arg[0])),
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
@ -428,6 +543,20 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
Returns:
Map transforming each UOp in the sink to the Ops.KERNEL graph.
"""
pushed = graph_rewrite(sink, index_pushing, ctx=[0], name="index pushing", bottom_up=True)
pushed = graph_rewrite(pushed, fix_buffers, name="fix buffers")
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
pushed = graph_rewrite(pushed, pm_reduce, ctx=ReduceContext(), name="remove reduce")
from tinygrad.codegen.linearize import block_create, BlockContext, pm_blockend_merge, block_merge, pm_finalize
pushed = graph_rewrite(pushed, block_create, ctx=BlockContext.from_sink(pushed), name="block create", bottom_up=True)
pushed = graph_rewrite(pushed, pm_blockend_merge, name="blockend merge")
pushed = graph_rewrite(pushed, block_merge, name="block merge")
pushed = graph_rewrite(pushed, pm_finalize, name="finalize")
from tinygrad.device import Device
try:
print(Device['CPU'].renderer.render(pushed.arg.lst))
except Exception as e:
print("render fail", e)
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")

View file

@ -10,6 +10,7 @@ class FastEnum(IntEnum):
class Ops(FastEnum):
# uops that aren't rendered
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
MAP = auto(); CATRANGE = auto()
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
@ -82,6 +83,7 @@ class GroupOp:
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}