Compare commits

...

51 commits

Author SHA1 Message Date
George Hotz
5766865193
Merge branch 'master' into no_merge_views 2025-08-15 08:20:46 -07:00
George Hotz
addd19d5e1 cleanups 2025-08-14 19:11:46 -07:00
George Hotz
66b92ffc82 one at a time 2025-08-14 16:43:15 -07:00
George Hotz
35116959ea test mnist passes 2025-08-14 16:24:32 -07:00
George Hotz
ffa08e9c94 rangeify works again 2025-08-14 15:04:41 -07:00
George Hotz
c735855dc0 test_plus works 2025-08-14 13:54:13 -07:00
George Hotz
a5d3b54f47 work 2025-08-14 13:52:14 -07:00
George Hotz
6131c0aad3 work 2025-08-14 13:17:49 -07:00
George Hotz
46caa43733 k splitting 2025-08-14 11:21:56 -07:00
George Hotz
4fd4e13fcf
Merge branch 'master' into no_merge_views 2025-08-14 08:07:52 -07:00
George Hotz
ab4ccf56a7 no 2025-08-14 08:07:06 -07:00
George Hotz
332630ddb5 threefry one kernel 2025-08-13 19:54:09 -07:00
George Hotz
e5eae3f524 assign becomes store 2025-08-13 19:11:59 -07:00
George Hotz
cae3616a68 cleanups 2025-08-13 19:02:53 -07:00
George Hotz
3aa80e7176 rangify bmnist 2025-08-13 18:44:43 -07:00
George Hotz
b1e2fb9afd rangeify in 2025-08-13 18:28:19 -07:00
George Hotz
59bfab8a9b sym 2025-08-13 17:49:54 -07:00
George Hotz
b5d7d339f4 no range arg 2025-08-13 17:43:17 -07:00
George Hotz
b7c195bf7e
Merge branch 'master' into no_merge_views 2025-08-13 16:21:41 -07:00
George Hotz
8592fba874
Merge branch 'master' into no_merge_views 2025-08-13 12:46:52 -07:00
George Hotz
10ffd7e17b simpler 2025-08-13 11:42:42 -07:00
George Hotz
5489be812c random works 2025-08-13 09:47:02 -07:00
George Hotz
e3d8185ba4 ignore that 2025-08-13 09:43:08 -07:00
George Hotz
cbf85fbfd0
Merge branch 'master' into no_merge_views 2025-08-13 09:40:44 -07:00
George Hotz
11d65cb002 test_gemm works 2025-08-11 18:58:40 -07:00
George Hotz
5f0816ef69 simpler 2025-08-11 18:41:32 -07:00
George Hotz
b8b28e1135
Merge branch 'master' into no_merge_views 2025-08-11 18:29:15 -07:00
George Hotz
9d46bc2939 endrange 2025-08-11 17:19:02 -07:00
George Hotz
2b7957e765 map_expand 2025-08-11 15:22:56 -07:00
George Hotz
4102e46370 cache has issues 2025-08-11 14:35:23 -07:00
George Hotz
6da4784c66
Merge branch 'master' into no_merge_views 2025-08-11 13:23:14 -07:00
George Hotz
2feeb8c8a6 cleanups 2025-08-11 08:57:30 -07:00
George Hotz
04fa825a26 careful w the cache 2025-08-10 15:52:41 -07:00
George Hotz
706188ad16 bugfix 2025-08-10 15:48:07 -07:00
George Hotz
fbe9909d90 update for master 2025-08-10 15:43:27 -07:00
George Hotz
16d2d9daac
Merge branch 'master' into no_merge_views 2025-08-10 15:39:37 -07:00
George Hotz
48ca6d888d was dumb 2025-08-10 14:36:35 -07:00
George Hotz
b7ea16f161 localish fa 2025-08-10 14:30:25 -07:00
George Hotz
cc34518a52 RewriteNotReady 2025-08-10 13:56:38 -07:00
George Hotz
76a97e04b0 cleanups 2025-08-10 12:24:16 -07:00
George Hotz
0d64aa1f1e this does work...but with a global 2025-08-10 12:00:55 -07:00
George Hotz
9979730f3f children stuff that doesn't work 2025-08-10 10:57:54 -07:00
George Hotz
7ddcb8632f simpler 2025-08-09 08:14:02 -07:00
George Hotz
e268eb2d5c tform ffn 2025-08-08 18:29:48 -07:00
George Hotz
ee06481036 ranges 2025-08-08 18:18:27 -07:00
George Hotz
38c9b5ed2c conv hack 2025-08-08 14:36:48 -07:00
George Hotz
7249a711c2 half contig 2025-08-08 08:43:41 -07:00
George Hotz
efdf08f3e2 global rangeify 2025-08-07 15:12:00 -07:00
George Hotz
9a2f55425b global rangeify 2025-08-07 15:08:45 -07:00
George Hotz
9ed409d0f8
Merge branch 'master' into no_merge_views 2025-08-07 14:42:07 -07:00
George Hotz
b8791e962c don't merge views, mops in kernel 2025-08-06 17:23:31 -07:00
10 changed files with 506 additions and 26 deletions

107
test/test_rangeify.py Normal file
View file

@ -0,0 +1,107 @@
import unittest
from tinygrad import Tensor
class TestRangeify(unittest.TestCase):
def test_double_gemm(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
(A@B@C).realize()
def test_double_gemm_exp(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
(((A@B).exp()@C).exp()).realize()
def test_double_gemm_relu(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
(((A@B).relu()@C).relu()).realize()
def test_double_gemm_relu_half_contig(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
(((A@B).relu().contiguous(arg=(1,))@C).relu()).realize()
def test_double_gemm_half_contig(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
((A@B).contiguous(arg=(1,))@C).realize()
def test_double_gemm_contig(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
((A@B).contiguous()@C).realize()
def test_many_gemm(self):
N = 1024
A = Tensor.empty(N, N)
B = Tensor.empty(N, N)
C = Tensor.empty(N, N)
D = Tensor.empty(N, N)
E = Tensor.empty(N, N)
F = Tensor.empty(N, N)
(A@B@C@D@E@F).realize()
def test_conv2d(self):
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 3, 3)
x.conv2d(w1).realize()
def test_conv2d_t(self):
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 3, 3)
(x*2).conv2d(w1).realize()
def test_double_conv2d(self):
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 3, 3)
w2 = Tensor.empty(12, 8, 3, 3)
x.conv2d(w1).conv2d(w2).realize()
def test_double_conv2d_half_contig(self):
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 3, 3)
w2 = Tensor.empty(12, 8, 3, 3)
# NOTE: this contiguous doesn't help
x.conv2d(w1).contiguous(arg=(1,)).conv2d(w2).permute(0,2,3,1).contiguous().realize()
def test_double_conv2d_contig(self):
x = Tensor.empty(1, 4, 32, 32)
w1 = Tensor.empty(8, 4, 3, 3)
w2 = Tensor.empty(12, 8, 3, 3)
x.conv2d(w1).contiguous().conv2d(w2).realize()
def test_transformer_ffn(self):
from tinygrad.apps.llm import TransformerBlock
from tinygrad import nn
blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)
for p in nn.state.get_parameters(blk): p.replace(Tensor.empty(p.shape))
x = Tensor.empty(128, 1024)
out = blk._feed_forward(x)
out.realize()
def test_flash_attention(self):
BS = 4
HEADS = 2
MATDIM = 16
EMB = 8
q = Tensor.empty(BS, HEADS, MATDIM, EMB)
k = Tensor.empty(BS, HEADS, MATDIM, EMB)
v = Tensor.empty(BS, HEADS, MATDIM, EMB)
q.scaled_dot_product_attention(k, v).realize()
if __name__ == '__main__':
unittest.main()

View file

@ -15,7 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
from tinygrad.codegen.opt.swizzler import merge_views
from tinygrad.schedule.kernelize import get_kernelize_map, Kernel
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
@ -1745,6 +1746,7 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(xt, 1)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]])
@unittest.skip("a")
def test_advanced_indexing(self):
X = Tensor.arange(10)+1
xt = X[[0, -1]]

View file

@ -30,7 +30,7 @@ class TestTiny(unittest.TestCase):
def test_gemm(self, N=64, out_dtype=dtypes.float):
a = Tensor.ones(N,N).contiguous()
b = Tensor.eye(N).contiguous()
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
self.assertListEqual((out:=a@b).contiguous().flatten().tolist(), [1.0]*(N*N))
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
# *** randomness ***
@ -103,7 +103,7 @@ class TestTiny(unittest.TestCase):
Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
# run model inference
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
probs = Tensor.empty(1, 1, 28, 28).sequential(layers).tolist()
self.assertEqual(len(probs[0]), 10)
# *** image ***

View file

@ -227,7 +227,7 @@ block_merge = PatternMatcher([
def finalize(sink:UOp) -> UOp:
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
raise RuntimeError("linearize failure")
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
# place the early things
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)

View file

@ -104,7 +104,8 @@ class CStyleLanguage(Renderer):
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})"}
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})",
Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
string_rewrite = base_rewrite
extra_matcher = extra_pm

View file

@ -1,18 +1,26 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
from tinygrad.uop.ops import track_rewrites, _substitute
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_uop_spec
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
from tinygrad.uop.symbolic import symbolic_simple, sym
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP, Timing
from tinygrad.dtype import ImageDType
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
from tinygrad.schedule.rangeify import pm_rangeify, RangeifyContext, ChildrenContext, pm_add_buffers, AddBufferContext, rangeify_fixups, pm_children
from tinygrad.codegen.opt.swizzler import apply_swizzle, swizzle_reduceop
# creation can recurse a lot
import sys
sys.setrecursionlimit(10000)
mops_merge = PatternMatcher([
# RESHAPE on RESHAPE is the second reshape
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
# non shape changing RESHAPE is NOOP
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
])
# **** schedule simplifier
def simplify_stride0_reduce(reduce:UOp, x:UOp):
@ -48,7 +56,7 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
sym = symbolic_simple+PatternMatcher([
kernelize_sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
@ -190,8 +198,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
bufs.append(s)
# replace global memory ops with the BUFFER they write to
# NOTE: merge_views is needed to unbind the reshapes
ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
ast = graph_rewrite(k.arg.ast, mops_merge+replace_buffers, bufs, bottom_up=True, name="replace buffers")
if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
return k.replace(arg=Kernel(ast, k.arg.metadata))
@ -314,6 +321,68 @@ finalize_contiguous = PatternMatcher([
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
new_fixups = mops_merge+PatternMatcher([
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
# TODO: this should be BUFFER_VIEW
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
])
# *** store splitting
@dataclass
class LocalAddBufferContext:
dg:int = 0
map:dict = field(default_factory=dict)
def debuf(ctx:LocalAddBufferContext, b:UOp): return UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=ctx.map[b][1])
def split_load(ctx:LocalAddBufferContext, s:UOp):
b = s.src[0].src[0]
if b.op is Ops.BUFFER:
if len(s.src) == 1:
lb = b
else:
assert len(s.src) == 2
lb = s.src[1]
assert b not in ctx.map or ctx.map[b][0] == lb
if b not in ctx.map:
ctx.map[b] = (lb, ctx.dg)
ctx.dg += 1
return s.replace(src=s.src[0:1]) if len(s.src) > 1 else None
def handle_store(ctx:LocalAddBufferContext, s:UOp):
b = s.src[0].src[0]
if b.op is Ops.BUFFER:
if b not in ctx.map:
ctx.map[b] = (b, ctx.dg)
ctx.dg += 1
if s.src[1].op is not Ops.COPY: return None
return s.src[1]
do_debuf = PatternMatcher([
(UPat(Ops.BUFFER, name="b"), debuf),
(UPat(Ops.COPY, name="c"), lambda c: c.src[0]),
])
to_define_global = PatternMatcher([
(UPat(Ops.BUFFER, name="b"), debuf),
(UPat(Ops.LOAD, name="s"), split_load),
(UPat(Ops.STORE, name="s"), handle_store),
])
def split_store(x:UOp):
shape = tuple([r.vmax+1 for r in x.src[2:]])
name = "k_"+'_'.join([str(s) for s in shape])
ctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global, ctx=ctx, name="* kernel split", bottom_up=True)
ret = ret.sink(arg=KernelInfo(name=name)) if ret.op is Ops.STORE else ret
kernel = UOp(Ops.KERNEL, src=tuple([x[0] for x in ctx.map.values()]), arg=Kernel(ret, ()))
return kernel.src[0].assign(kernel)
split_kernels = PatternMatcher([
(UPat(Ops.STORE, name="x"), split_store)
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
@ -325,12 +394,48 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
Returns:
Map transforming each UOp in the sink to the Ops.KERNEL graph.
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
# testing
# NOTE: graph_rewrite_map with bottom_up is broken
with Timing("*** rangeify in "):
#tensor_map = graph_rewrite_map(tensor_map[sink], remove_tags, bottom_up=True, input_map=tensor_map, name="* remove tags")
forced_contig = [x.base for x in tensor_map[sink].src]
#for u in tensor_map[sink].toposort():
# if u.op is Ops.COPY: forced_contig.append(u)
tensor_map = graph_rewrite_map(tensor_map[sink], rangeify_fixups, bottom_up=True, ctx=forced_contig, input_map=tensor_map, name="* contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="* children")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="* rangeify")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, input_map=tensor_map, name="* buffer")
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="* split kernels")
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
return tensor_map
"""
rsink = tensor_map[sink]
rsink = graph_rewrite(rsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="* rangeify")
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, name="* buffer")
rsink = graph_rewrite(rsink, do_debuf, ctx=[], name="* debuf")
"""
#if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
#rsink = graph_rewrite(rsink, sym, name="* symbolic")
#from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
#rsink = graph_rewrite(rsink, pm_reduce, ctx=ReduceContext(), name="* remove reduce")
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
rsink = apply_rewrites(rsink, rewrites_for_linearizer)
from tinygrad.renderer.cstyle import CStyleLanguage
src = CStyleLanguage().render(rsink.arg.lst)
print(src)
return {}
#return tensor_map
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
# insert contiguous in places determined by the realize map
realize_map = group_realizes(tensor_map[sink])

View file

@ -0,0 +1,256 @@
from typing import Any
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace, PtrDType
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady
from tinygrad.helpers import argsort, prod, all_same
rangeify_fixups = PatternMatcher([
(UPat(GroupOp.All, name="x"), lambda ctx,x: x.replace(tag=69).contiguous(tag=2).reshape(x.shape) if x in ctx and x.tag != 69 else None),
# all contiguous on COPY
#(UPat(Ops.COPY, name="x"), lambda x: x.replace(tag=69).contiguous(tag=2).reshape(x.shape) if x.tag != 69 else None),
# double contiguous merge
(UPat(Ops.CONTIGUOUS, name="c2", src=(UPat(Ops.CONTIGUOUS, name="c1"))),
lambda c1,c2: c1.replace(tag=2 if c2.tag == 2 or c1.tag == 2 else None) if c1.arg is None and c2.arg is None else None),
# const
#(UPat(Ops.CONST, name="x"), lambda x:
# x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
# len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
])
@dataclass
class ChildrenContext:
children: dict[UOp, list[UOp]]|None = None
def extract_children(ctx:ChildrenContext, x:UOp):
if ctx.children is not None: return
# REDUCE_AXIS is fine here, should go to contig only (gate)
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1 and any(x.op is Ops.REDUCE_AXIS for x in k.toposort())}
def mark_children(ctx:ChildrenContext, x:UOp):
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
return x.replace(src=tuple(new_srcs))
pm_children = PatternMatcher([
(UPat(Ops.SINK, name="x"), extract_children),
(UPat(GroupOp.All-{Ops.CHILD}, name="x"), mark_children),
# hack for one kernel threefry
#(UPat(Ops.CHILD, src=(UPat(Ops.THREEFRY, name="x"),)), lambda x: x),
])
@dataclass
class RangeifyContext:
idx: int = 0
regs: int = 0
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
seen_child: dict[UOp, Any] = field(default_factory=dict)
is_sink_contig: tuple[UOp, ...] = ()
def map_reshape(x:UOp, r:UOp):
acc = 1
to_sum = []
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
to_sum.append(acc*src)
acc *= s
mish = sum(to_sum)
ret = []
for s in r.src[0].shape[::-1]:
if resolve(s!=1):
# this MOD should limit any ranges outside s
ret.append(mish % s)
mish //= s
else:
ret.append(UOp.const(dtypes.int, 0))
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
return r.src[0].index(*ret, dtype=x.dtype)
def map_pad(x:UOp, r:UOp):
ret = list(x.src[1:])
bigwhere = UOp.const(dtypes.bool, True)
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
if s == 0 and e == 0: continue
where = UOp.const(dtypes.bool, True)
if e > 0: where = where & (ret[i] < (sh-e))
if s > 0: where = where & (ret[i] >= s)
bigwhere = bigwhere & where
# this is safe but dumb
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
# mask the load
#ret[i] = where.where(ret[i], UOp(Ops.INVALID, dtype=ret[i].dtype))
# PAD is with 0
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
def map_expand(r:UOp, x:UOp):
new_rngs = []
ending_ranges = []
non_ending_ranges = []
for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape):
axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE]
if resolve(x!=y, False):
ending_ranges.extend(axis_to_range)
new_rngs.append(a.const_like(0))
else:
non_ending_ranges.extend(axis_to_range)
new_rngs.append(a)
ending_ranges = [x for x in ending_ranges if x not in non_ending_ranges]
ret = r.src[0]
ret = UOp(Ops.ENDRANGE, dtype=ret.dtype, src=(ret,)+tuple(ending_ranges)) if len(ending_ranges) else ret
return ret.index(*new_rngs)
pm_mops = PatternMatcher([
# this is like the definitions of these
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)], dtype=x.dtype)),
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
lambda r,x: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)])),
# expand needs to end ranges
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"), map_expand),
# reshape does a lot of symbolic stuff
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
# pad adds min and max
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
])
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
if x.tag == 1: return None
ranges = []
new_ranges = []
passthrough_idx = []
for i,s in enumerate(x.shape):
if x.arg is not None and i not in x.arg:
assert idx is not None, "partial contig requires index"
ranges.append(idx.src[1+i])
continue
if idx is not None: passthrough_idx.append(idx.src[1+i])
if resolve(s!=1):
ranges.append(UOp.range(dtypes.int, s, ctx.idx))
new_ranges.append(ranges[-1])
ctx.idx += 1
else:
ranges.append(UOp.const(dtypes.int, 0))
ret = x.src[0].index(*ranges).pcontiguous(*new_ranges, arg=x.arg)
# if there's no open ranges, set arg to None so this uses a DEFINE_GLOBAL
if len(ret.ranges) == 0: ret = ret.replace(arg=None)
ret = ret.index(*passthrough_idx) if len(passthrough_idx) else ret
return ret
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
# TODO: this should be in the cache
#print(f"reduce {id(red)}")
rngs = list(idx.src[1:])
new_ranges = []
for i,s in enumerate(red.src[0].shape):
if i in red.arg[1]:
rngs[i] = UOp.range(dtypes.int, s, ctx.idx)
ctx.idx += 1
new_ranges.append(rngs[i])
return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0])
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
#print(f"visit CHILD {x.arg} bottom up")
if c not in ctx.seen_children: ctx.seen_children[c] = {}
ctx.seen_children[c][x.arg[0]] = idx
# wait here until we have seen all the children
if len(ctx.seen_children[c]) != x.arg[1]: raise RewriteNotReady
if c not in ctx.seen_child:
all_rngs = zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()])
out_rngs = []
end_ranges = []
idx_ranges = []
for i,r in enumerate(all_rngs):
if all_same(r):
out_rngs.append(r[0])
else:
out_rngs.append(UOp.range(dtypes.int, c.shape[i], ctx.idx))
ctx.idx += 1
end_ranges.append(out_rngs[-1])
idx_ranges.append(i)
ctx.seen_child[c] = (idx_ranges, end_ranges)
else:
out_rngs = list(idx.src[1:])
idx_ranges, end_ranges = ctx.seen_child[c]
for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
if len(idx_ranges) == 0: return c.index(*out_rngs)
return c.index(*out_rngs).pcontiguous(*end_ranges, arg=tuple(idx_ranges)).index(*[idx.src[1+i] for i in idx_ranges])
def indexed_endrange(er:UOp, idx:UOp):
ended = er.src[1:]
earliest_ending_axis = min([x.arg for x in ended])
to_end_axis = []
for i,a in enumerate(idx.src[1:]):
if any(x.arg > earliest_ending_axis for x in a.toposort() if x.op is Ops.RANGE):
to_end_axis.append(i)
if to_end_axis: return idx.replace(src=(er.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:])
return idx.replace(src=(er.src[0],)+idx.src[1:])
pm_rangeify = pm_mops+PatternMatcher([
# if there are new ended children, tag the SINK
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
# if there's an INDEX it can support partial contig
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, name="x"),), allow_any_len=True, name="idx"), map_contiguous),
# sink contigs to kick it off
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: map_contiguous(ctx, x).reshape(x.shape) if x.tag == 2 else None),
# handle ENDRANGE on movement
(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Movement),), allow_any_len=True, name="er"),
lambda er: er.src[0].replace(src=(UOp(Ops.ENDRANGE, dtype=er.dtype, src=(er.src[0].src[0],)+er.src[1:]),))),
# handle ENDRANGE on BUFFER
# and CHILD: python3 test/test_schedule.py TestSchedule.test_cache_reduce_parent
(UPat(Ops.ENDRANGE, src=(UPat((Ops.BUFFER, Ops.CONST, Ops.CONTIGUOUS, Ops.CHILD)),), allow_any_len=True, name="er"), lambda er: er.src[0]),
# handle INDEXed ENDRANGE
(UPat(Ops.INDEX, src=(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="er"),),
allow_any_len=True, name="idx"), indexed_endrange),
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
# CONTIGUOUS on ASSIGN is STORE
# TODO: tag in UPat?
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c", allow_any_len=True),
lambda c,a: UOp(Ops.STORE, src=a.src+c.src[1:]) if c.tag == 1 else None),
])
@dataclass
class AddBufferContext:
dg:int = 0
map:dict = field(default_factory=dict)
def add_store(ctx:AddBufferContext, x:UOp):
rngs = x.src[1:]
shape = tuple([r.vmax+1 for r in rngs])
assert prod(shape) > 0, f"no zero sized buffers {shape}"
if x.arg is None or prod(shape) > 65536:
buf = UOp.new_buffer(x.device, prod(shape), x.dtype)
else:
buf = UOp(Ops.DEFINE_LOCAL, dtype=x.dtype.ptr(size=prod(shape), addrspace=AddrSpace.LOCAL), arg=ctx.dg)
ctx.map[buf] = (buf.op, ctx.dg)
ctx.dg += 1
return buf.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=prod(shape))).store(x.src[0], *rngs)
def add_load(ctx:AddBufferContext, x:UOp, b:UOp, idx:UOp):
if isinstance(x.dtype, PtrDType): return None
return x.replace(dtype=x.dtype.ptr(b.size)).load()
def add_load_on_store(ctx:AddBufferContext, x:UOp, st:UOp):
rngs = x.src[1:]
shape = tuple([r.vmax+1 for r in rngs])
b = st.src[0].src[0]
assert b.op is Ops.BUFFER
return b.shrink(((0,prod(shape)),)).reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=b.size)).load(st)
pm_add_buffers = pm_mops+PatternMatcher([
(UPat(Ops.PCONTIGUOUS, name="x"), add_store),
(UPat(Ops.ENDRANGE, name="x"), lambda x: x.src[0]),
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
(UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
(UPat(Ops.BIND, name="b"), lambda b: b.src[0]),
# CONST can't have axes. remove srcs when we idx
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),
# HACK: consts shouldn't have srcs by here
(UPat(Ops.CONST, name="x"), lambda x: x.replace(src=()) if len(x.src) else None),
])

View file

@ -19,6 +19,7 @@ class Ops(FastEnum):
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
PCONTIGUOUS = auto()
# blocks in linearizer (only used there)
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702

View file

@ -136,10 +136,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def st(self) -> ShapeTracker|None:
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
if self.op is Ops.INDEX and self.src[0].op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER}: return None
if self.op in GroupOp.Block: return None
from tinygrad.shape.shapetracker import ShapeTracker
# VIEW and MovementOps define a new ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
# CONST with a DEVICE has a shape of ()
if self.op is Ops.CONST and len(self.src) and self.src[0].op is Ops.DEVICE: return ShapeTracker.from_shape(())
@ -158,7 +160,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
# otherwise we get the shape from sources
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
if not (src_sts := [x.st for x in self.src if x.st is not None and x.op is not Ops.INDEX]): return None
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
match self.op:
case Ops.MULTI: shape = tuple(self.src[0].shape[a]*len(self.device) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
@ -186,7 +188,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def ranges(self) -> dict[UOp, None]:
if self.op is Ops.RANGE: return {self:None}
if self.op in {Ops.CONTIGUOUS, Ops.REDUCE, Ops.STORE}:
if self.op in {Ops.PCONTIGUOUS, Ops.REDUCE, Ops.STORE}:
ret = self.src[0].ranges.copy()
for s in self.src[1:]:
if s in ret: del ret[s]
@ -232,7 +234,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
def index(self, *srcs:UOp|None, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): return self.index(idx)
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
@ -269,11 +272,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
# TODO: clean this all up with rangeify
if shape is not None:
from tinygrad.shape.shapetracker import ShapeTracker
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
if device is not None:
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
if shape is not None:
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
else:
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
return ret
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
@ -291,6 +298,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
def pcontiguous(self, *args, **kwargs): return UOp(Ops.PCONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
def fuse(self): return self.alu(Ops.FUSE)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
@ -359,7 +367,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def _mop(self, op:Ops, arg) -> UOp:
ret = UOp(op, self.dtype, (self,), arg)
if self.st == ret.st: return self # ignore NOOPs, also check ret.st
if self.st is not None and self.st == ret.st: return self # ignore NOOPs, also check ret.st
return ret
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
@ -935,7 +943,7 @@ class RewriteContext:
else:
# in stage 2, we link the result of new_n to the result of n
try: self.replace[n] = self.replace[new_n]
except KeyError: raise RuntimeError("infinite loop in graph_rewrite (explicit)") # pylint: disable=raise-missing-from
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
except RewriteNotReady:
# retry this later
stack.insert(0, (n, stage, new_n))
@ -951,7 +959,7 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
new_map: dict[UOp, UOp] = {}
for k in sink.toposort():
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
if input_map is not None:

View file

@ -19,7 +19,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.CHILD: "#80fff0"}
Ops.PCONTIGUOUS: "#FFC18D", Ops.CHILD: "#80fff0"}
# VIZ API