mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
51 commits
master
...
no_merge_v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5766865193 |
||
|
|
addd19d5e1 | ||
|
|
66b92ffc82 | ||
|
|
35116959ea | ||
|
|
ffa08e9c94 | ||
|
|
c735855dc0 | ||
|
|
a5d3b54f47 | ||
|
|
6131c0aad3 | ||
|
|
46caa43733 | ||
|
|
4fd4e13fcf |
||
|
|
ab4ccf56a7 | ||
|
|
332630ddb5 | ||
|
|
e5eae3f524 | ||
|
|
cae3616a68 | ||
|
|
3aa80e7176 | ||
|
|
b1e2fb9afd | ||
|
|
59bfab8a9b | ||
|
|
b5d7d339f4 | ||
|
|
b7c195bf7e |
||
|
|
8592fba874 |
||
|
|
10ffd7e17b | ||
|
|
5489be812c | ||
|
|
e3d8185ba4 | ||
|
|
cbf85fbfd0 |
||
|
|
11d65cb002 | ||
|
|
5f0816ef69 | ||
|
|
b8b28e1135 |
||
|
|
9d46bc2939 | ||
|
|
2b7957e765 | ||
|
|
4102e46370 | ||
|
|
6da4784c66 |
||
|
|
2feeb8c8a6 | ||
|
|
04fa825a26 | ||
|
|
706188ad16 | ||
|
|
fbe9909d90 | ||
|
|
16d2d9daac |
||
|
|
48ca6d888d | ||
|
|
b7ea16f161 | ||
|
|
cc34518a52 | ||
|
|
76a97e04b0 | ||
|
|
0d64aa1f1e | ||
|
|
9979730f3f | ||
|
|
7ddcb8632f | ||
|
|
e268eb2d5c | ||
|
|
ee06481036 | ||
|
|
38c9b5ed2c | ||
|
|
7249a711c2 | ||
|
|
efdf08f3e2 | ||
|
|
9a2f55425b | ||
|
|
9ed409d0f8 |
||
|
|
b8791e962c |
10 changed files with 506 additions and 26 deletions
107
test/test_rangeify.py
Normal file
107
test/test_rangeify.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor
|
||||
|
||||
class TestRangeify(unittest.TestCase):
|
||||
def test_double_gemm(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
(A@B@C).realize()
|
||||
|
||||
def test_double_gemm_exp(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
(((A@B).exp()@C).exp()).realize()
|
||||
|
||||
def test_double_gemm_relu(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
(((A@B).relu()@C).relu()).realize()
|
||||
|
||||
def test_double_gemm_relu_half_contig(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
(((A@B).relu().contiguous(arg=(1,))@C).relu()).realize()
|
||||
|
||||
def test_double_gemm_half_contig(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
((A@B).contiguous(arg=(1,))@C).realize()
|
||||
|
||||
def test_double_gemm_contig(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
((A@B).contiguous()@C).realize()
|
||||
|
||||
def test_many_gemm(self):
|
||||
N = 1024
|
||||
A = Tensor.empty(N, N)
|
||||
B = Tensor.empty(N, N)
|
||||
C = Tensor.empty(N, N)
|
||||
D = Tensor.empty(N, N)
|
||||
E = Tensor.empty(N, N)
|
||||
F = Tensor.empty(N, N)
|
||||
(A@B@C@D@E@F).realize()
|
||||
|
||||
def test_conv2d(self):
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 3, 3)
|
||||
x.conv2d(w1).realize()
|
||||
|
||||
def test_conv2d_t(self):
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 3, 3)
|
||||
(x*2).conv2d(w1).realize()
|
||||
|
||||
def test_double_conv2d(self):
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 3, 3)
|
||||
w2 = Tensor.empty(12, 8, 3, 3)
|
||||
x.conv2d(w1).conv2d(w2).realize()
|
||||
|
||||
def test_double_conv2d_half_contig(self):
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 3, 3)
|
||||
w2 = Tensor.empty(12, 8, 3, 3)
|
||||
# NOTE: this contiguous doesn't help
|
||||
x.conv2d(w1).contiguous(arg=(1,)).conv2d(w2).permute(0,2,3,1).contiguous().realize()
|
||||
|
||||
def test_double_conv2d_contig(self):
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 3, 3)
|
||||
w2 = Tensor.empty(12, 8, 3, 3)
|
||||
x.conv2d(w1).contiguous().conv2d(w2).realize()
|
||||
|
||||
def test_transformer_ffn(self):
|
||||
from tinygrad.apps.llm import TransformerBlock
|
||||
from tinygrad import nn
|
||||
blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)
|
||||
for p in nn.state.get_parameters(blk): p.replace(Tensor.empty(p.shape))
|
||||
|
||||
x = Tensor.empty(128, 1024)
|
||||
out = blk._feed_forward(x)
|
||||
out.realize()
|
||||
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
HEADS = 2
|
||||
MATDIM = 16
|
||||
EMB = 8
|
||||
q = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
k = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
v = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
q.scaled_dot_product_attention(k, v).realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -15,7 +15,8 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
|||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
|
||||
from tinygrad.codegen.opt.swizzler import merge_views
|
||||
from tinygrad.schedule.kernelize import get_kernelize_map, Kernel
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
|
|
@ -1745,6 +1746,7 @@ class TestIndexing(unittest.TestCase):
|
|||
self.check_schedule(xt, 1)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]])
|
||||
|
||||
@unittest.skip("a")
|
||||
def test_advanced_indexing(self):
|
||||
X = Tensor.arange(10)+1
|
||||
xt = X[[0, -1]]
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestTiny(unittest.TestCase):
|
|||
def test_gemm(self, N=64, out_dtype=dtypes.float):
|
||||
a = Tensor.ones(N,N).contiguous()
|
||||
b = Tensor.eye(N).contiguous()
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
self.assertListEqual((out:=a@b).contiguous().flatten().tolist(), [1.0]*(N*N))
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** randomness ***
|
||||
|
|
@ -103,7 +103,7 @@ class TestTiny(unittest.TestCase):
|
|||
Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
|
||||
|
||||
# run model inference
|
||||
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
|
||||
probs = Tensor.empty(1, 1, 28, 28).sequential(layers).tolist()
|
||||
self.assertEqual(len(probs[0]), 10)
|
||||
|
||||
# *** image ***
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ block_merge = PatternMatcher([
|
|||
|
||||
def finalize(sink:UOp) -> UOp:
|
||||
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
|
||||
raise RuntimeError("linearize failure")
|
||||
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
|
||||
|
||||
# place the early things
|
||||
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
||||
|
|
|
|||
|
|
@ -104,7 +104,8 @@ class CStyleLanguage(Renderer):
|
|||
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
|
||||
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
|
||||
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})"}
|
||||
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})",
|
||||
Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
|
||||
|
||||
string_rewrite = base_rewrite
|
||||
extra_matcher = extra_pm
|
||||
|
|
|
|||
|
|
@ -1,18 +1,26 @@
|
|||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP
|
||||
from tinygrad.uop.symbolic import symbolic_simple, sym
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, prod, dedup, unwrap, getenv, pluralize, FUSE_ARANGE, DEBUG, SPLIT_REDUCEOP, Timing
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
|
||||
from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
|
||||
from tinygrad.schedule.rangeify import pm_rangeify, RangeifyContext, ChildrenContext, pm_add_buffers, AddBufferContext, rangeify_fixups, pm_children
|
||||
from tinygrad.codegen.opt.swizzler import apply_swizzle, swizzle_reduceop
|
||||
|
||||
# creation can recurse a lot
|
||||
import sys
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
mops_merge = PatternMatcher([
|
||||
# RESHAPE on RESHAPE is the second reshape
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
|
||||
# non shape changing RESHAPE is NOOP
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
|
||||
])
|
||||
|
||||
# **** schedule simplifier
|
||||
|
||||
def simplify_stride0_reduce(reduce:UOp, x:UOp):
|
||||
|
|
@ -48,7 +56,7 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
|
|||
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
|
||||
return base.copy_to_device(copy.device).view(view.arg)
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
kernelize_sym = symbolic_simple+PatternMatcher([
|
||||
# UOp with size 0 is zero
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
|
||||
|
|
@ -190,8 +198,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||
bufs.append(s)
|
||||
# replace global memory ops with the BUFFER they write to
|
||||
# NOTE: merge_views is needed to unbind the reshapes
|
||||
ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
|
||||
ast = graph_rewrite(k.arg.ast, mops_merge+replace_buffers, bufs, bottom_up=True, name="replace buffers")
|
||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
|
@ -314,6 +321,68 @@ finalize_contiguous = PatternMatcher([
|
|||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
new_fixups = mops_merge+PatternMatcher([
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
|
||||
# TODO: this should be BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
|
||||
])
|
||||
|
||||
# *** store splitting
|
||||
|
||||
@dataclass
|
||||
class LocalAddBufferContext:
|
||||
dg:int = 0
|
||||
map:dict = field(default_factory=dict)
|
||||
|
||||
def debuf(ctx:LocalAddBufferContext, b:UOp): return UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=ctx.map[b][1])
|
||||
|
||||
def split_load(ctx:LocalAddBufferContext, s:UOp):
|
||||
b = s.src[0].src[0]
|
||||
if b.op is Ops.BUFFER:
|
||||
if len(s.src) == 1:
|
||||
lb = b
|
||||
else:
|
||||
assert len(s.src) == 2
|
||||
lb = s.src[1]
|
||||
assert b not in ctx.map or ctx.map[b][0] == lb
|
||||
if b not in ctx.map:
|
||||
ctx.map[b] = (lb, ctx.dg)
|
||||
ctx.dg += 1
|
||||
return s.replace(src=s.src[0:1]) if len(s.src) > 1 else None
|
||||
|
||||
def handle_store(ctx:LocalAddBufferContext, s:UOp):
|
||||
b = s.src[0].src[0]
|
||||
if b.op is Ops.BUFFER:
|
||||
if b not in ctx.map:
|
||||
ctx.map[b] = (b, ctx.dg)
|
||||
ctx.dg += 1
|
||||
if s.src[1].op is not Ops.COPY: return None
|
||||
return s.src[1]
|
||||
|
||||
do_debuf = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="b"), debuf),
|
||||
(UPat(Ops.COPY, name="c"), lambda c: c.src[0]),
|
||||
])
|
||||
|
||||
to_define_global = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="b"), debuf),
|
||||
(UPat(Ops.LOAD, name="s"), split_load),
|
||||
(UPat(Ops.STORE, name="s"), handle_store),
|
||||
])
|
||||
|
||||
def split_store(x:UOp):
|
||||
shape = tuple([r.vmax+1 for r in x.src[2:]])
|
||||
name = "k_"+'_'.join([str(s) for s in shape])
|
||||
ctx = LocalAddBufferContext()
|
||||
ret = graph_rewrite(x, to_define_global, ctx=ctx, name="* kernel split", bottom_up=True)
|
||||
ret = ret.sink(arg=KernelInfo(name=name)) if ret.op is Ops.STORE else ret
|
||||
kernel = UOp(Ops.KERNEL, src=tuple([x[0] for x in ctx.map.values()]), arg=Kernel(ret, ()))
|
||||
return kernel.src[0].assign(kernel)
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
(UPat(Ops.STORE, name="x"), split_store)
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
|
|
@ -325,12 +394,48 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
Returns:
|
||||
Map transforming each UOp in the sink to the Ops.KERNEL graph.
|
||||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
# testing
|
||||
# NOTE: graph_rewrite_map with bottom_up is broken
|
||||
with Timing("*** rangeify in "):
|
||||
#tensor_map = graph_rewrite_map(tensor_map[sink], remove_tags, bottom_up=True, input_map=tensor_map, name="* remove tags")
|
||||
forced_contig = [x.base for x in tensor_map[sink].src]
|
||||
#for u in tensor_map[sink].toposort():
|
||||
# if u.op is Ops.COPY: forced_contig.append(u)
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], rangeify_fixups, bottom_up=True, ctx=forced_contig, input_map=tensor_map, name="* contiguous")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="* children")
|
||||
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="* rangeify")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, input_map=tensor_map, name="* buffer")
|
||||
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="* split kernels")
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
||||
return tensor_map
|
||||
|
||||
"""
|
||||
rsink = tensor_map[sink]
|
||||
rsink = graph_rewrite(rsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="* rangeify")
|
||||
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=AddBufferContext(), bottom_up=True, name="* buffer")
|
||||
rsink = graph_rewrite(rsink, do_debuf, ctx=[], name="* debuf")
|
||||
"""
|
||||
#if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
|
||||
#rsink = graph_rewrite(rsink, sym, name="* symbolic")
|
||||
|
||||
#from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
|
||||
#rsink = graph_rewrite(rsink, pm_reduce, ctx=ReduceContext(), name="* remove reduce")
|
||||
|
||||
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
|
||||
rsink = apply_rewrites(rsink, rewrites_for_linearizer)
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
src = CStyleLanguage().render(rsink.arg.lst)
|
||||
print(src)
|
||||
return {}
|
||||
#return tensor_map
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
||||
# insert contiguous in places determined by the realize map
|
||||
realize_map = group_realizes(tensor_map[sink])
|
||||
|
|
|
|||
256
tinygrad/schedule/rangeify.py
Normal file
256
tinygrad/schedule/rangeify.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace, PtrDType
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady
|
||||
from tinygrad.helpers import argsort, prod, all_same
|
||||
|
||||
rangeify_fixups = PatternMatcher([
|
||||
(UPat(GroupOp.All, name="x"), lambda ctx,x: x.replace(tag=69).contiguous(tag=2).reshape(x.shape) if x in ctx and x.tag != 69 else None),
|
||||
# all contiguous on COPY
|
||||
#(UPat(Ops.COPY, name="x"), lambda x: x.replace(tag=69).contiguous(tag=2).reshape(x.shape) if x.tag != 69 else None),
|
||||
# double contiguous merge
|
||||
(UPat(Ops.CONTIGUOUS, name="c2", src=(UPat(Ops.CONTIGUOUS, name="c1"))),
|
||||
lambda c1,c2: c1.replace(tag=2 if c2.tag == 2 or c1.tag == 2 else None) if c1.arg is None and c2.arg is None else None),
|
||||
# const
|
||||
#(UPat(Ops.CONST, name="x"), lambda x:
|
||||
# x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
|
||||
# len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
class ChildrenContext:
|
||||
children: dict[UOp, list[UOp]]|None = None
|
||||
|
||||
def extract_children(ctx:ChildrenContext, x:UOp):
|
||||
if ctx.children is not None: return
|
||||
# REDUCE_AXIS is fine here, should go to contig only (gate)
|
||||
ctx.children = {k:list(v.keys()) for k,v in x.get_children_map().items() if len(v) > 1 and any(x.op is Ops.REDUCE_AXIS for x in k.toposort())}
|
||||
def mark_children(ctx:ChildrenContext, x:UOp):
|
||||
new_srcs = [(UOp(Ops.CHILD, s.dtype, src=(s,), arg=(ctx.children[s].index(x), len(ctx.children[s]))) if s in ctx.children else s) for s in x.src]
|
||||
return x.replace(src=tuple(new_srcs))
|
||||
pm_children = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), extract_children),
|
||||
(UPat(GroupOp.All-{Ops.CHILD}, name="x"), mark_children),
|
||||
|
||||
# hack for one kernel threefry
|
||||
#(UPat(Ops.CHILD, src=(UPat(Ops.THREEFRY, name="x"),)), lambda x: x),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
class RangeifyContext:
|
||||
idx: int = 0
|
||||
regs: int = 0
|
||||
seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
|
||||
seen_child: dict[UOp, Any] = field(default_factory=dict)
|
||||
is_sink_contig: tuple[UOp, ...] = ()
|
||||
|
||||
def map_reshape(x:UOp, r:UOp):
|
||||
acc = 1
|
||||
to_sum = []
|
||||
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
|
||||
to_sum.append(acc*src)
|
||||
acc *= s
|
||||
mish = sum(to_sum)
|
||||
ret = []
|
||||
for s in r.src[0].shape[::-1]:
|
||||
if resolve(s!=1):
|
||||
# this MOD should limit any ranges outside s
|
||||
ret.append(mish % s)
|
||||
mish //= s
|
||||
else:
|
||||
ret.append(UOp.const(dtypes.int, 0))
|
||||
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
|
||||
return r.src[0].index(*ret, dtype=x.dtype)
|
||||
|
||||
def map_pad(x:UOp, r:UOp):
|
||||
ret = list(x.src[1:])
|
||||
bigwhere = UOp.const(dtypes.bool, True)
|
||||
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
|
||||
if s == 0 and e == 0: continue
|
||||
where = UOp.const(dtypes.bool, True)
|
||||
if e > 0: where = where & (ret[i] < (sh-e))
|
||||
if s > 0: where = where & (ret[i] >= s)
|
||||
bigwhere = bigwhere & where
|
||||
# this is safe but dumb
|
||||
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
|
||||
# mask the load
|
||||
#ret[i] = where.where(ret[i], UOp(Ops.INVALID, dtype=ret[i].dtype))
|
||||
# PAD is with 0
|
||||
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
|
||||
|
||||
def map_expand(r:UOp, x:UOp):
|
||||
new_rngs = []
|
||||
ending_ranges = []
|
||||
non_ending_ranges = []
|
||||
for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape):
|
||||
axis_to_range = [u for u in a.toposort() if u.op is Ops.RANGE]
|
||||
if resolve(x!=y, False):
|
||||
ending_ranges.extend(axis_to_range)
|
||||
new_rngs.append(a.const_like(0))
|
||||
else:
|
||||
non_ending_ranges.extend(axis_to_range)
|
||||
new_rngs.append(a)
|
||||
ending_ranges = [x for x in ending_ranges if x not in non_ending_ranges]
|
||||
ret = r.src[0]
|
||||
ret = UOp(Ops.ENDRANGE, dtype=ret.dtype, src=(ret,)+tuple(ending_ranges)) if len(ending_ranges) else ret
|
||||
return ret.index(*new_rngs)
|
||||
|
||||
pm_mops = PatternMatcher([
|
||||
# this is like the definitions of these
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)], dtype=x.dtype)),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[x.src[1+p] for p in argsort(x.src[0].arg)])),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)])),
|
||||
# expand needs to end ranges
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"), map_expand),
|
||||
# reshape does a lot of symbolic stuff
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
|
||||
# pad adds min and max
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
|
||||
])
|
||||
|
||||
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
|
||||
if x.tag == 1: return None
|
||||
ranges = []
|
||||
new_ranges = []
|
||||
passthrough_idx = []
|
||||
for i,s in enumerate(x.shape):
|
||||
if x.arg is not None and i not in x.arg:
|
||||
assert idx is not None, "partial contig requires index"
|
||||
ranges.append(idx.src[1+i])
|
||||
continue
|
||||
if idx is not None: passthrough_idx.append(idx.src[1+i])
|
||||
if resolve(s!=1):
|
||||
ranges.append(UOp.range(dtypes.int, s, ctx.idx))
|
||||
new_ranges.append(ranges[-1])
|
||||
ctx.idx += 1
|
||||
else:
|
||||
ranges.append(UOp.const(dtypes.int, 0))
|
||||
ret = x.src[0].index(*ranges).pcontiguous(*new_ranges, arg=x.arg)
|
||||
# if there's no open ranges, set arg to None so this uses a DEFINE_GLOBAL
|
||||
if len(ret.ranges) == 0: ret = ret.replace(arg=None)
|
||||
ret = ret.index(*passthrough_idx) if len(passthrough_idx) else ret
|
||||
return ret
|
||||
|
||||
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
|
||||
# TODO: this should be in the cache
|
||||
#print(f"reduce {id(red)}")
|
||||
rngs = list(idx.src[1:])
|
||||
new_ranges = []
|
||||
for i,s in enumerate(red.src[0].shape):
|
||||
if i in red.arg[1]:
|
||||
rngs[i] = UOp.range(dtypes.int, s, ctx.idx)
|
||||
ctx.idx += 1
|
||||
new_ranges.append(rngs[i])
|
||||
return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0])
|
||||
|
||||
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
|
||||
#print(f"visit CHILD {x.arg} bottom up")
|
||||
if c not in ctx.seen_children: ctx.seen_children[c] = {}
|
||||
ctx.seen_children[c][x.arg[0]] = idx
|
||||
# wait here until we have seen all the children
|
||||
if len(ctx.seen_children[c]) != x.arg[1]: raise RewriteNotReady
|
||||
|
||||
if c not in ctx.seen_child:
|
||||
all_rngs = zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()])
|
||||
out_rngs = []
|
||||
end_ranges = []
|
||||
idx_ranges = []
|
||||
for i,r in enumerate(all_rngs):
|
||||
if all_same(r):
|
||||
out_rngs.append(r[0])
|
||||
else:
|
||||
out_rngs.append(UOp.range(dtypes.int, c.shape[i], ctx.idx))
|
||||
ctx.idx += 1
|
||||
end_ranges.append(out_rngs[-1])
|
||||
idx_ranges.append(i)
|
||||
ctx.seen_child[c] = (idx_ranges, end_ranges)
|
||||
else:
|
||||
out_rngs = list(idx.src[1:])
|
||||
idx_ranges, end_ranges = ctx.seen_child[c]
|
||||
for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
|
||||
if len(idx_ranges) == 0: return c.index(*out_rngs)
|
||||
return c.index(*out_rngs).pcontiguous(*end_ranges, arg=tuple(idx_ranges)).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
|
||||
def indexed_endrange(er:UOp, idx:UOp):
|
||||
ended = er.src[1:]
|
||||
earliest_ending_axis = min([x.arg for x in ended])
|
||||
to_end_axis = []
|
||||
for i,a in enumerate(idx.src[1:]):
|
||||
if any(x.arg > earliest_ending_axis for x in a.toposort() if x.op is Ops.RANGE):
|
||||
to_end_axis.append(i)
|
||||
if to_end_axis: return idx.replace(src=(er.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:])
|
||||
return idx.replace(src=(er.src[0],)+idx.src[1:])
|
||||
|
||||
pm_rangeify = pm_mops+PatternMatcher([
|
||||
# if there are new ended children, tag the SINK
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
|
||||
|
||||
# if there's an INDEX it can support partial contig
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, name="x"),), allow_any_len=True, name="idx"), map_contiguous),
|
||||
|
||||
# sink contigs to kick it off
|
||||
(UPat(Ops.CONTIGUOUS, name="x"), lambda ctx,x: map_contiguous(ctx, x).reshape(x.shape) if x.tag == 2 else None),
|
||||
|
||||
# handle ENDRANGE on movement
|
||||
(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Movement),), allow_any_len=True, name="er"),
|
||||
lambda er: er.src[0].replace(src=(UOp(Ops.ENDRANGE, dtype=er.dtype, src=(er.src[0].src[0],)+er.src[1:]),))),
|
||||
# handle ENDRANGE on BUFFER
|
||||
# and CHILD: python3 test/test_schedule.py TestSchedule.test_cache_reduce_parent
|
||||
(UPat(Ops.ENDRANGE, src=(UPat((Ops.BUFFER, Ops.CONST, Ops.CONTIGUOUS, Ops.CHILD)),), allow_any_len=True, name="er"), lambda er: er.src[0]),
|
||||
# handle INDEXed ENDRANGE
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.ENDRANGE, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="er"),),
|
||||
allow_any_len=True, name="idx"), indexed_endrange),
|
||||
|
||||
# move MAP through elementwise ALU / reduce. these are the items with cost
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
|
||||
|
||||
# CONTIGUOUS on ASSIGN is STORE
|
||||
# TODO: tag in UPat?
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(Ops.ASSIGN, name="a"),), name="c", allow_any_len=True),
|
||||
lambda c,a: UOp(Ops.STORE, src=a.src+c.src[1:]) if c.tag == 1 else None),
|
||||
])
|
||||
|
||||
@dataclass
|
||||
class AddBufferContext:
|
||||
dg:int = 0
|
||||
map:dict = field(default_factory=dict)
|
||||
|
||||
def add_store(ctx:AddBufferContext, x:UOp):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([r.vmax+1 for r in rngs])
|
||||
assert prod(shape) > 0, f"no zero sized buffers {shape}"
|
||||
if x.arg is None or prod(shape) > 65536:
|
||||
buf = UOp.new_buffer(x.device, prod(shape), x.dtype)
|
||||
else:
|
||||
buf = UOp(Ops.DEFINE_LOCAL, dtype=x.dtype.ptr(size=prod(shape), addrspace=AddrSpace.LOCAL), arg=ctx.dg)
|
||||
ctx.map[buf] = (buf.op, ctx.dg)
|
||||
ctx.dg += 1
|
||||
return buf.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=prod(shape))).store(x.src[0], *rngs)
|
||||
|
||||
def add_load(ctx:AddBufferContext, x:UOp, b:UOp, idx:UOp):
|
||||
if isinstance(x.dtype, PtrDType): return None
|
||||
return x.replace(dtype=x.dtype.ptr(b.size)).load()
|
||||
|
||||
def add_load_on_store(ctx:AddBufferContext, x:UOp, st:UOp):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([r.vmax+1 for r in rngs])
|
||||
b = st.src[0].src[0]
|
||||
assert b.op is Ops.BUFFER
|
||||
return b.shrink(((0,prod(shape)),)).reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=b.size)).load(st)
|
||||
|
||||
pm_add_buffers = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.PCONTIGUOUS, name="x"), add_store),
|
||||
(UPat(Ops.ENDRANGE, name="x"), lambda x: x.src[0]),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat(name="idx")), name="x"), add_load),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
|
||||
(UPat(Ops.BIND, name="b"), lambda b: b.src[0]),
|
||||
# CONST can't have axes. remove srcs when we idx
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),
|
||||
# HACK: consts shouldn't have srcs by here
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.replace(src=()) if len(x.src) else None),
|
||||
])
|
||||
|
|
@ -19,6 +19,7 @@ class Ops(FastEnum):
|
|||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
|
||||
PCONTIGUOUS = auto()
|
||||
|
||||
# blocks in linearizer (only used there)
|
||||
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -136,10 +136,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
@functools.cached_property
|
||||
def st(self) -> ShapeTracker|None:
|
||||
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
|
||||
if self.op is Ops.INDEX and self.src[0].op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER}: return None
|
||||
if self.op in GroupOp.Block: return None
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
# VIEW and MovementOps define a new ShapeTracker from the arg
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
|
||||
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
# CONST with a DEVICE has a shape of ()
|
||||
if self.op is Ops.CONST and len(self.src) and self.src[0].op is Ops.DEVICE: return ShapeTracker.from_shape(())
|
||||
|
|
@ -158,7 +160,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
|
||||
|
||||
# otherwise we get the shape from sources
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None and x.op is not Ops.INDEX]): return None
|
||||
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
|
||||
match self.op:
|
||||
case Ops.MULTI: shape = tuple(self.src[0].shape[a]*len(self.device) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
|
||||
|
|
@ -186,7 +188,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
@functools.cached_property
|
||||
def ranges(self) -> dict[UOp, None]:
|
||||
if self.op is Ops.RANGE: return {self:None}
|
||||
if self.op in {Ops.CONTIGUOUS, Ops.REDUCE, Ops.STORE}:
|
||||
if self.op in {Ops.PCONTIGUOUS, Ops.REDUCE, Ops.STORE}:
|
||||
ret = self.src[0].ranges.copy()
|
||||
for s in self.src[1:]:
|
||||
if s in ret: del ret[s]
|
||||
|
|
@ -232,7 +234,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
|
||||
def index(self, *srcs:UOp|None, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, idx): return self.index(idx)
|
||||
def const_like(self, b:ConstLike):
|
||||
# constants can optionally have a DEVICE source
|
||||
|
|
@ -269,11 +272,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
|
||||
# TODO: clean this all up with rangeify
|
||||
if shape is not None:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
|
||||
if device is not None:
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
if shape is not None:
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
else:
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||
|
|
@ -291,6 +298,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
||||
def pcontiguous(self, *args, **kwargs): return UOp(Ops.PCONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def fuse(self): return self.alu(Ops.FUSE)
|
||||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
|
||||
|
|
@ -359,7 +367,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
def _mop(self, op:Ops, arg) -> UOp:
|
||||
ret = UOp(op, self.dtype, (self,), arg)
|
||||
if self.st == ret.st: return self # ignore NOOPs, also check ret.st
|
||||
if self.st is not None and self.st == ret.st: return self # ignore NOOPs, also check ret.st
|
||||
return ret
|
||||
|
||||
def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
|
||||
|
|
@ -935,7 +943,7 @@ class RewriteContext:
|
|||
else:
|
||||
# in stage 2, we link the result of new_n to the result of n
|
||||
try: self.replace[n] = self.replace[new_n]
|
||||
except KeyError: raise RuntimeError("infinite loop in graph_rewrite (explicit)") # pylint: disable=raise-missing-from
|
||||
except KeyError: raise RewriteNotReady # pylint: disable=raise-missing-from
|
||||
except RewriteNotReady:
|
||||
# retry this later
|
||||
stack.insert(0, (n, stage, new_n))
|
||||
|
|
@ -951,7 +959,7 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
|
|||
input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
|
||||
new_map: dict[UOp, UOp] = {}
|
||||
for k in sink.toposort():
|
||||
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
|
||||
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
|
||||
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
|
||||
if input_map is not None:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.CHILD: "#80fff0"}
|
||||
Ops.PCONTIGUOUS: "#FFC18D", Ops.CHILD: "#80fff0"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue