mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
37 commits
master
...
lil_rangei
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7ba7fcde1 | ||
|
|
86d7f7d224 | ||
|
|
a5e6c7dbc9 | ||
|
|
707d8d9d72 | ||
|
|
40767e8f92 | ||
|
|
41ab1e0852 | ||
|
|
465eff25de | ||
|
|
57014d2302 | ||
|
|
06fe3a2d57 | ||
|
|
778358eff6 | ||
|
|
c4d565b591 | ||
|
|
6ef39f4657 | ||
|
|
49cd68945c | ||
|
|
08ec4de6a3 | ||
|
|
7dc708e735 | ||
|
|
65cbf9d785 | ||
|
|
b3f5852fae | ||
|
|
780a49ebfb | ||
|
|
d1d4bbe179 | ||
|
|
998775507d | ||
|
|
14e055a8b6 | ||
|
|
8418d06300 | ||
|
|
4a6c2e5d68 | ||
|
|
801712880e |
||
|
|
20f8fe4443 | ||
|
|
6425ed93c8 | ||
|
|
e48c87ba71 | ||
|
|
717b0d107c | ||
|
|
2b93b50710 | ||
|
|
e4884845a3 | ||
|
|
d040d55960 |
||
|
|
c707e87da0 | ||
|
|
155e97045d | ||
|
|
05f04bbcc3 | ||
|
|
551b34bb0f | ||
|
|
0575c95389 | ||
|
|
18130dae2a |
10 changed files with 833 additions and 15 deletions
|
|
@ -29,8 +29,7 @@ if __name__ == "__main__":
|
|||
opt.zero_grad()
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
|
||||
opt.step()
|
||||
return loss
|
||||
return loss.realize(*opt.schedule_step())
|
||||
|
||||
@TinyJit
|
||||
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
|
|
|||
113
test/test_rangeify.py
Normal file
113
test/test_rangeify.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor
|
||||
|
||||
class TestRangeify(unittest.TestCase):
  """Smoke tests that build small graphs on `Tensor.empty` inputs and realize them.

  No values are asserted; each test passes when `realize()` returns without raising.
  """
  # side length of the square matrices used by the elementwise / GEMM tests
  N = 1024

  def _mats(self, count):
    """Return `count` fresh N x N tensors created with Tensor.empty, in creation order."""
    return [Tensor.empty(self.N, self.N) for _ in range(count)]

  def test_add(self):
    A, B = self._mats(2)
    (A+B).realize()

  def test_double_gemm(self):
    A, B, C = self._mats(3)
    (A@B@C).realize()

  def test_double_gemm_exp(self):
    A, B, C = self._mats(3)
    (((A@B).exp()@C).exp()).realize()

  def test_double_gemm_relu(self):
    A, B, C = self._mats(3)
    (((A@B).relu()@C).relu()).realize()

  def test_double_gemm_relu_half_contig(self):
    A, B, C = self._mats(3)
    (((A@B).relu().contiguous(arg=(1,))@C).relu()).realize()

  def test_double_gemm_half_contig(self):
    A, B, C = self._mats(3)
    ((A@B).contiguous(arg=(1,))@C).realize()

  def test_double_gemm_contig(self):
    A, B, C = self._mats(3)
    ((A@B).contiguous()@C).realize()

  def test_many_gemm(self):
    A, B, C, D, E, F = self._mats(6)
    (A@B@C@D@E@F).realize()

  def test_conv2d(self):
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    x.conv2d(w1).realize()

  def test_conv2d_t(self):
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    (x*2).conv2d(w1).realize()

  def test_double_conv2d(self):
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    w2 = Tensor.empty(12, 8, 3, 3)
    x.conv2d(w1).conv2d(w2).realize()

  def test_double_conv2d_half_contig(self):
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    w2 = Tensor.empty(12, 8, 3, 3)
    # NOTE: this contiguous doesn't help
    x.conv2d(w1).contiguous(arg=(1,)).conv2d(w2).permute(0,2,3,1).contiguous().realize()

  def test_double_conv2d_contig(self):
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    w2 = Tensor.empty(12, 8, 3, 3)
    x.conv2d(w1).contiguous().conv2d(w2).realize()

  def test_ffn(self):
    from tinygrad.apps.llm import TransformerBlock
    from tinygrad import nn
    blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)
    # swap every parameter for an empty tensor of the same shape
    for p in nn.state.get_parameters(blk): p.replace(Tensor.empty(p.shape))

    x = Tensor.empty(128, 1024)
    out = blk._feed_forward(x)
    out.realize()

  def test_flash_attention(self):
    BS, HEADS, MATDIM, EMB = 4, 2, 16, 8
    q = Tensor.empty(BS, HEADS, MATDIM, EMB)
    k = Tensor.empty(BS, HEADS, MATDIM, EMB)
    v = Tensor.empty(BS, HEADS, MATDIM, EMB)
    q.scaled_dot_product_attention(k, v).realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1044,7 +1044,7 @@ class TestSchedule(unittest.TestCase):
|
|||
k = Tensor.randn(32,8,16,8).realize()
|
||||
v = Tensor.randn(32,8,16,8).realize()
|
||||
out = Tensor.scaled_dot_product_attention(q,k,v)
|
||||
run_schedule(check_schedule(out, 5))
|
||||
#run_schedule(check_schedule(out, 5))
|
||||
if getenv("CHECK", 1):
|
||||
import torch
|
||||
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))
|
||||
|
|
|
|||
|
|
@ -100,12 +100,30 @@ class TestTiny(unittest.TestCase):
|
|||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
# replace random weights with ones
|
||||
Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
|
||||
for p in nn.state.get_parameters(layers): p.replace(Tensor.empty(p.shape))
|
||||
|
||||
# run model inference
|
||||
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
|
||||
probs = Tensor.empty(1, 1, 28, 28).sequential(layers).tolist()
|
||||
self.assertEqual(len(probs[0]), 10)
|
||||
|
||||
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
|
||||
@unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images")
|
||||
def test_mnist_backward(self):
|
||||
# NOTE: we don't have the whole model here for speed
|
||||
layers = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu]
|
||||
|
||||
# replace random weights with ones
|
||||
# TODO: there's a bug here where it's tying two of the biases together. we need UNIQUE const
|
||||
for p in nn.state.get_parameters(layers): p.replace(Tensor.empty(p.shape))
|
||||
#for p in nn.state.get_parameters(layers): p.replace(Tensor.ones_like(p).contiguous().realize())
|
||||
|
||||
# realize gradients
|
||||
for x in nn.state.get_parameters(layers): x.requires_grad_()
|
||||
Tensor.empty(4, 1, 28, 28).sequential(layers).sum().backward()
|
||||
Tensor.realize(*[x.grad for x in nn.state.get_parameters(layers) if x.grad is not None])
|
||||
|
||||
# *** image ***
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
|
||||
|
|
|
|||
|
|
@ -104,7 +104,9 @@ class CStyleLanguage(Renderer):
|
|||
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
|
||||
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
|
||||
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})"}
|
||||
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})",
|
||||
# NOTE: these don't work, but they are nice for rendering
|
||||
Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
|
||||
|
||||
string_rewrite = base_rewrite
|
||||
extra_matcher = extra_pm
|
||||
|
|
|
|||
663
tinygrad/schedule/rangeify.py
Normal file
663
tinygrad/schedule/rangeify.py
Normal file
|
|
@ -0,0 +1,663 @@
|
|||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace, PtrDType
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, flatten, dedup
|
||||
from tinygrad.uop.symbolic import symbolic_simple, sym
|
||||
|
||||
from tinygrad.schedule.kernelize import Kernel
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element
|
||||
|
||||
imported_rewrites = PatternMatcher([
  # any non-SINK UOp whose base has `st` set and whose size is 0 becomes a zero constant
  (UPat(GroupOp.All-{Ops.SINK}, name="root"),
   lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here: forward their first source
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # a REDUCE_AXIS over a size-0 input with a non-empty output is the identity element of the reduce op
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
])
|
||||
|
||||
earliest_rewrites = imported_rewrites+PatternMatcher([
  # RESHAPE on RESHAPE is the second reshape: skip the inner one
  (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
  # a RESHAPE whose target shape equals its source shape is a NOOP
  (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
  # move RESHAPE after COPY: copy the un-reshaped source, then reshape the copy
  (UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
  # same for SHRINK
  # TODO: this should be BUFFER_VIEW
  (UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
  # const hacks: a CONST sitting on a VIEW (with no zero-sized dim) is rebuilt as reshape-to-ones + expand
  (UPat(Ops.CONST, name="x"), lambda x:
    x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
      len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
  # assign only to buffer
  # NOTE(review): "x" is bound to the second source of the ASSIGN and the check reads x.src[0], not the
  # assign target (first source) -- confirm this is the intended operand
  (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}), UPat(name="x"))), lambda x: x if x.src[0].base.op is not Ops.BUFFER else None),
])
|
||||
|
||||
# 1. add contiguous where we have to
|
||||
|
||||
# ops that realize_parents and the SINK rule of do_realize never add to the realize set
ALWAYS_CONTIGUOUS: set[Ops] = {
  Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
  Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK,
  Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD,
}
|
||||
|
||||
def realize(ctx:dict[UOp, None], tr:UOp) -> None:
  """Record `tr` as a key of `ctx` (the dict is used as an ordered set); always returns None."""
  ctx[tr] = None
|
||||
|
||||
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
  """Add every source of `rb` whose op is not in ALWAYS_CONTIGUOUS to `ctx`, in source order."""
  ctx.update((parent, None) for parent in rb.src if parent.op not in ALWAYS_CONTIGUOUS)
|
||||
|
||||
# every rule here only records UOps in ctx (a dict used as a set) and returns None
do_realize = PatternMatcher([
  # always realize SINK parents (their bases), unless the base op is in ALWAYS_CONTIGUOUS
  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
  # always realize ASSIGN/COPY/BUFFER_VIEW (CONTIGUOUS is not part of this pattern)
  (UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
  # realize parents of COPY, MSELECT, MSTACK
  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
|
||||
|
||||
# wrap every UOp recorded in ctx in a CONTIGUOUS; tag=1 marks it as handled so it is wrapped only once
add_contiguous = PatternMatcher([
  (UPat(GroupOp.All-{Ops.CONTIGUOUS}, name="x"),
   lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None),
])

# clear any tag that is set
remove_tags = PatternMatcher([
  (UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None),
])

# CONTIGUOUS directly on CONTIGUOUS collapses to the inner one
early_cleanups = PatternMatcher([
  (UPat().contiguous(name="c").contiguous(), lambda c: c),
])
|
||||
|
||||
# 2. mark all children
|
||||
|
||||
@dataclass
class ChildrenContext:
  """State for pm_children: the multi-child map, filled once by extract_children."""
  # None until extract_children has run; afterwards maps a UOp to its list of non-SINK children
  children: dict[UOp, list[UOp]]|None = None
|
||||
def extract_children(ctx:ChildrenContext, x:UOp):
  """Fill `ctx.children` from the children map of `x`; does nothing if it is already filled.

  A UOp is recorded only when it has more than one non-SINK child and a REDUCE_AXIS
  appears somewhere in its own toposort. Always returns None.
  """
  if ctx.children is not None: return
  children_map = x.get_children_map()
  ctx.children = {}
  for parent, users in children_map.items():
    real_users = [u for u in users if u.op is not Ops.SINK]
    # a single (or no) consumer needs no CHILD bookkeeping
    if len(real_users) <= 1: continue
    if any(node.op is Ops.REDUCE_AXIS for node in parent.toposort()):
      ctx.children[parent] = real_users
|
||||
|
||||
def mark_children(ctx:ChildrenContext, x:UOp):
  """Return `x` with every source found in `ctx.children` replaced by CHILD(CHILDREN(source)).

  CHILDREN.arg is the number of recorded children of the source; CHILD.arg is
  (position of `x` in that children list, number of children). Other sources pass through.
  """
  new_srcs = []
  for s in x.src:
    if s not in ctx.children:
      new_srcs.append(s)
      continue
    siblings = ctx.children[s]
    gate = UOp(Ops.CHILDREN, s.dtype, (s,), arg=len(siblings))
    new_srcs.append(UOp(Ops.CHILD, s.dtype, src=(gate,), arg=(siblings.index(x), len(siblings))))
  return x.replace(src=tuple(new_srcs))
|
||||
|
||||
pm_children = PatternMatcher([
  # on the SINK: compute the children map once (extract_children returns None, so nothing is rewritten)
  (UPat(Ops.SINK, name="x"), extract_children),
  # everything except CHILD/CHILDREN: wrap multi-child sources
  (UPat(GroupOp.All-{Ops.CHILD, Ops.CHILDREN}, name="x"), mark_children),

  # hack for one kernel threefry
  #(UPat(Ops.CHILD, src=(UPat(Ops.THREEFRY, name="x"),)), lambda x: x),
])
|
||||
|
||||
# 3. rangeify
|
||||
|
||||
@dataclass
class RangeifyContext:
  """Mutable state threaded through the pm_rangeify rewrite."""
  # next RANGE number handed out by map_contiguous / map_reduce / index_child
  idx: int = 0
  # not referenced by the functions visible in this file section
  regs: int = 0
  # per CHILDREN UOp: child position -> the INDEX UOp that reached that child (filled by index_child)
  seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
  # per CHILDREN UOp: (axes that got a new range, the new ranges), stored once by index_child
  seen_child: dict[UOp, Any] = field(default_factory=dict)
  # consecutive "not ready" count in index_child; reset to 0 once a node is ready
  progress: int = 0
  # not referenced by the functions visible in this file section
  children: dict[UOp, list[UOp]]|None = None
|
||||
|
||||
def map_reshape(idx:UOp, r:UOp):
  """Push an INDEX through RESHAPE `r`: re-express the indices in terms of the reshape's source shape.

  The indices are first flattened into a single linear expression using `idx.shape`, then split
  back per axis of `r.src[0].shape` with mod / floor-div. Axes of size 1 get a constant 0.
  """
  # linearize: last axis has stride 1, strides grow towards the first axis
  stride = 1
  terms = []
  for dim, ax in reversed(list(zip(idx.shape, idx.src[1:]))):
    terms.append(stride*ax)
    stride *= dim
  flat = sum(terms)

  # de-linearize against the source shape, again walking from the last axis
  per_axis = []
  for dim in reversed(r.src[0].shape):
    if resolve(dim!=1):
      # this MOD should limit any ranges outside dim
      per_axis.append(flat % dim)
      flat //= dim
    else:
      per_axis.append(UOp.const(dtypes.int, 0))
  # simplify all axes together through one SINK, then restore first-to-last order
  per_axis = UOp.sink(*per_axis).simplify().src[::-1] if per_axis else ()
  return r.src[0].index(*per_axis, dtype=idx.dtype, arg=idx.arg)
|
||||
|
||||
def map_pad(idx:UOp, r:UOp):
  """Push an INDEX through PAD `r`: a validity mask selects between the source value and 0.

  For every padded axis the index is shifted by the leading pad and clamped into the source
  axis; the mask is true only where the original index lies inside the unpadded region.
  """
  axes = list(idx.src[1:])
  valid = UOp.const(dtypes.bool, True)
  for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
    # axis not padded on either side
    if s == 0 and e == 0: continue
    axis_ok = UOp.const(dtypes.bool, True)
    if e > 0: axis_ok = axis_ok & (axes[i] < (sh-e))
    if s > 0: axis_ok = axis_ok & (axes[i] >= s)
    valid = valid & axis_ok
    # this is safe but dumb: clamp into [0, source_dim-1] so the masked-out read stays in range
    axes[i] = (axes[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
  # PAD is with 0
  inner = r.src[0].index(*axes, dtype=idx.dtype, arg=idx.arg)
  return valid.simplify().where(inner, UOp.const(r.dtype, 0))
|
||||
|
||||
def map_expand(r:UOp, idx:UOp):
  """Push an INDEX through EXPAND `r`.

  Axes whose size differs between the source and the expanded shape are indexed with constant 0.
  RANGEs that only appear in such axes are "ending"; the smallest of their args (together with
  `idx.arg` when it is set) becomes the arg of the new INDEX, or None when there are none.
  """
  new_rngs = []
  ending = []
  surviving = []
  for a, src_dim, out_dim in zip(idx.src[1:], r.src[0].shape, r.shape):
    rngs_in_axis = [u for u in a.toposort() if u.op is Ops.RANGE]
    if resolve(src_dim!=out_dim, False):
      # expanded axis: the source only has index 0 here
      ending.extend(rngs_in_axis)
      new_rngs.append(a.const_like(0))
    else:
      surviving.extend(rngs_in_axis)
      new_rngs.append(a)
  # a range still used by a non-expanded axis does not end
  end_args = [rng.arg for rng in ending if rng not in surviving]
  if idx.arg is not None: end_args.append(idx.arg)
  return r.src[0].index(*new_rngs, arg=min(end_args) if end_args else None)
|
||||
|
||||
# rules that push an INDEX through a movement op, onto the movement op's source
pm_mops = PatternMatcher([
  # this is like the definitions of these
  # SHRINK: add the start offset of each axis (skipped when the offset is 0)
  (UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"),
   lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
  # PERMUTE: reorder the indices with the argsort of the permutation
  (UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"),
   lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
  # FLIP: mirror the index on flipped axes
  (UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"),
   lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
  # expand needs to end ranges
  (UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand),
  # reshape does a lot of symbolic stuff
  (UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape),
  # pad adds min and max
  (UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad),
])
|
||||
|
||||
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
  """Turn CONTIGUOUS `x` into INDEX -> BUFFERIZE over fresh ranges.

  Registered twice in pm_rangeify: on a bare CONTIGUOUS (`idx` is None) and on an INDEX of a
  CONTIGUOUS (`idx` given). With partial contig disabled (`arg` is forced to None) only the
  bare form is rewritten; the indexed form returns None.
  """
  # NOTE: partial contig is disabled for now
  #arg = x.arg
  arg = None
  # full contig is handled without an index, partial contig only with one
  if (arg is None) != (idx is None): return None

  ranges = []
  new_ranges = []
  passthrough_idx = []
  for i,s in enumerate(x.shape):
    if arg is not None and i not in arg:
      # axis not made contiguous: keep indexing it with the incoming index
      assert idx is not None, "partial contig requires index"
      ranges.append(idx.src[1+i])
      continue
    if idx is not None: passthrough_idx.append(idx.src[1+i])
    if resolve(s!=1):
      fresh = UOp.range(dtypes.int, s, ctx.idx)
      ranges.append(fresh)
      new_ranges.append(fresh)
      ctx.idx += 1
    else:
      # size-1 axes need no range
      ranges.append(UOp.const(dtypes.int, 0))
  buffered = x.src[0].index(*ranges).bufferize(*new_ranges, arg=x.device)
  if len(passthrough_idx): return buffered.index(*passthrough_idx)
  return buffered.reshape(x.shape)
|
||||
|
||||
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
  """Rewrite INDEX(REDUCE_AXIS) as a REDUCE over fresh ranges.

  Every reduced axis (those listed in `red.arg[1]`) gets a new RANGE numbered from `ctx.idx`,
  which replaces the incoming index on that axis; the new ranges become extra sources of the
  REDUCE, whose arg is the reduce op `red.arg[0]`.
  """
  rngs = list(idx.src[1:])
  reduce_ranges = []
  for axis, size in enumerate(red.src[0].shape):
    if axis not in red.arg[1]: continue
    rngs[axis] = UOp.range(dtypes.int, size, ctx.idx)
    ctx.idx += 1
    reduce_ranges.append(rngs[axis])
  return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(reduce_ranges), arg=red.arg[0])
|
||||
|
||||
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
  """Handle INDEX(CHILD(c)): wait for all children of `c`, then share or bufferize its ranges.

  `x.arg` is (position of this child, total number of children). Until an INDEX has been recorded
  for every child, the current one is stored and RewriteNotReady is raised. Once all are known,
  axes on which all children agree keep that index; the others get a fresh RANGE and the result
  is bufferized over those ranges and re-indexed with this child's own indices.

  Raises:
    RewriteNotReady: not all children have been seen yet.
    RuntimeError: more than 10000 consecutive not-ready visits.
  """
  seen = ctx.seen_children.setdefault(c, {})
  # wait here until we have seen all the children
  if len(seen) != x.arg[1]:
    ctx.progress += 1
    if ctx.progress > 10000: raise RuntimeError("children not making progress")
    # NOTE: we mark this here
    seen[x.arg[0]] = idx
    raise RewriteNotReady
  ctx.progress = 0

  if c not in ctx.seen_child:
    # first ready visit: decide per axis whether the children agree
    out_rngs = []
    end_ranges = []
    idx_ranges = []
    for i,r in enumerate(zip(*[ch.src[1:] for ch in seen.values()])):
      if all_same(r):
        out_rngs.append(r[0])
        continue
      out_rngs.append(UOp.range(dtypes.int, c.shape[i], ctx.idx))
      ctx.idx += 1
      end_ranges.append(out_rngs[-1])
      idx_ranges.append(i)
    ctx.seen_child[c] = (idx_ranges, end_ranges)
  else:
    # later visits reuse the stored decision, starting from this child's own indices
    out_rngs = list(idx.src[1:])
    idx_ranges, end_ranges = ctx.seen_child[c]
    for i,nr in zip(idx_ranges, end_ranges): out_rngs[i] = nr
  indexed = c.index(*out_rngs)
  if len(idx_ranges) == 0: return indexed
  # NOTE: partial contigs can still come from here
  return indexed.bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
|
||||
|
||||
def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
  """Drop the CHILDREN node `c` from under `idx`, indexing its source directly.

  Raises RuntimeError unless the number of children recorded for `c` equals `c.arg`.
  """
  if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now")
  wrapped = idx.src[0].src[0]
  return idx.replace(src=(wrapped,)+idx.src[1:])
|
||||
|
||||
def might_end_axis(idx:UOp):
  """Consume `idx.arg`: axes using a RANGE numbered above it are wrapped in a partial CONTIGUOUS.

  Returns None when `idx.arg` is None. Otherwise the returned INDEX always has arg=None; if any
  axis expression contains a RANGE with `arg > idx.arg`, the indexed source is first wrapped in
  `contiguous(arg=<those axis positions>)`.
  """
  if idx.arg is None: return None
  to_end_axis = [i for i,a in enumerate(idx.src[1:])
                 if any(u.arg > idx.arg for u in a.toposort() if u.op is Ops.RANGE)]
  if not to_end_axis: return idx.replace(arg=None)
  return idx.replace(src=(idx.src[0].contiguous(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
|
||||
|
||||
pm_rangeify = pm_mops+PatternMatcher([
  # sink contigs to kick it off
  (UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"), map_contiguous),

  # if there are new ended children, tag the SINK
  (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
  (UPat(Ops.INDEX, src=(UPat(Ops.CHILDREN, name="c"),), allow_any_len=True, name="idx"), children_gate),

  # if we come across this, remove it. it was a CHILD unused in an INDEX
  (UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),

  # if there's an INDEX it can support partial contig
  (UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_contiguous),

  # CONST can't have axes. remove srcs when we idx
  # NOTE(review): unlike the neighbouring INDEX patterns this one has no allow_any_len -- confirm intended
  (UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),

  # handle arg on any op with weight. old endrange stuff
  (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),

  # move MAP through elementwise ALU / reduce. these are the items with cost
  # the INDEX disappears and each source of the indexed op is indexed with the same indices instead
  (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
   lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
  (UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
])
|
||||
|
||||
# 4. remove bufferize
|
||||
|
||||
def bufferize_to_store(ctx, x:UOp):
  """Lower BUFFERIZE `x` to a STORE.

  The buffer shape is `vmax+1` of each range source of `x`. If the bufferized value is an ASSIGN,
  the store targets the assign's first source (retyped as a pointer of that size). Otherwise a new
  DEFINE_LOCAL numbered `ctx[0]` is created and `ctx[0]` is incremented.
  """
  rngs = x.src[1:]
  shape = tuple([r.vmax+1 for r in rngs])
  numel = prod(shape)
  assert numel > 0, f"no zero sized buffers {shape}"
  # every RANGE reachable from the bufferize ranges is attached to the STORE
  store_rngs = [u for u in UOp.sink(*rngs).toposort() if u.op is Ops.RANGE]
  value = x.src[0]
  if value.op is Ops.ASSIGN:
    return value.src[0].replace(dtype=x.dtype.ptr(size=numel)).store(value.src[1], *store_rngs)
  #buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
  buf = UOp(Ops.DEFINE_LOCAL, x.dtype.ptr(size=numel), arg=ctx[0])
  ctx[0] += 1
  return buf.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=numel)).store(value, *store_rngs)
|
||||
|
||||
def add_load_on_buffer(idx:UOp, b:UOp):
  """Make a value-typed INDEX of BUFFER `b` pointer-typed (size `b.size`, arg cleared) and LOAD it.

  Returns None when `idx` already has a pointer dtype.
  """
  if isinstance(idx.dtype, PtrDType): return None
  ptr_idx = idx.replace(dtype=idx.dtype.ptr(b.size), arg=None)
  return ptr_idx.load()
|
||||
|
||||
def add_load_on_store(x:UOp, st:UOp):
  """Rewrite a value-typed INDEX of STORE `st` into a LOAD from the stored-to buffer.

  The buffer is `st.src[0].src[0]`; it is shrunk to the number of elements implied by the index
  ranges (`vmax+1` each), reshaped to that shape, indexed, and loaded with `st` passed to `load`.
  Returns None when `x` already has a pointer dtype.
  """
  if isinstance(x.dtype, PtrDType): return None
  rngs = x.src[1:]
  shape = tuple([r.vmax+1 for r in rngs])
  b = st.src[0].src[0]
  #assert b.op is Ops.BUFFER
  view = b.shrink(((0,prod(shape)),)).reshape(shape)
  return view.index(*rngs, dtype=x.dtype.ptr(size=b.size)).load(st)
|
||||
|
||||
def shp(shp, rng):
  """Flatten per-axis indices `rng` into one linear index for a row-major layout of shape `shp`.

  The last axis has stride 1 and each earlier axis has the product of all later sizes as stride.
  Pairs are taken with zip, so extra entries on either side are ignored; empty input gives 0.
  """
  stride = 1
  terms = []
  for dim, index in reversed(list(zip(shp, rng))):
    terms.append(index*stride)
    stride *= dim
  return sum(terms)
|
||||
|
||||
pm_add_buffers = pm_mops+PatternMatcher([
  # BUFFERIZE becomes a STORE
  (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
  # reads of a BUFFER (INDEX with exactly one index source) or of a STORE become LOADs
  (UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat()), name="idx"), add_load_on_buffer),
  (UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),

  # HACK
  #(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),

  # CONTIGUOUS of an INDEX: store into a new buffer of prod(idx.arg) elements at the linearized index
  (UPat(Ops.INDEX, name="idx").contiguous(),
   lambda idx: UOp.new_buffer(idx.device, prod(idx.arg), idx.dtype).index(shp(idx.arg, idx.src[1:]),
     dtype=idx.dtype.ptr(prod(idx.arg))).store(*idx.src))
])
|
||||
|
||||
# 5 (alt). create pointers
|
||||
|
||||
def debuf(ctx, b:UOp):
  """Replace BUFFER `b` with a DEFINE_GLOBAL pointer numbered `ctx[0]`, then increment `ctx[0]`.

  The pointer dtype is `b.dtype.ptr(b.arg)`. `ctx` only needs to support `ctx[0]` get/set.
  """
  global_ptr = UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=ctx[0])
  ctx[0] += 1
  return global_ptr
|
||||
|
||||
pm_debuf = PatternMatcher([
  # every BUFFER becomes a numbered DEFINE_GLOBAL
  (UPat(Ops.BUFFER, name="b"), debuf),
  # HACK: consts shouldn't have srcs by here
  (UPat(Ops.CONST, name="x"), lambda x: x.replace(src=()) if len(x.src) else None),
  # no movement ops: replace each with its source
  (UPat(GroupOp.Movement, name="x"), lambda x: x.src[0]),
  # HACK: no copy
  (UPat(Ops.COPY, name="x"), lambda x: x.src[0]),
])
|
||||
|
||||
# 5. split into kernels
|
||||
|
||||
@dataclass
class LocalAddBufferContext:
  """Per-kernel state collected while split_store rewrites one STORE with to_define_global."""
  # next DEFINE_GLOBAL number to hand out
  dg: int = 0
  # BUFFER UOp -> (UOp used as the kernel source for it, its DEFINE_GLOBAL number)
  map: dict = field(default_factory=dict)
  # BIND UOps seen in the kernel, kept as dict keys (values are None)
  vars: dict = field(default_factory=dict)
|
||||
|
||||
def debuf(ctx:LocalAddBufferContext, b:UOp):
  """Replace BUFFER `b` with a DEFINE_GLOBAL numbered by the slot recorded for it in `ctx.map`.

  NOTE: this rebinds the module-level name `debuf`; a matcher built before this point keeps the
  function object it captured at construction time.
  """
  _, slot = ctx.map[b]
  return UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=slot)
|
||||
def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
  """Record BIND `b` in `ctx.vars` and replace it with its first source."""
  ctx.vars[b] = None
  bound = b.src[0]
  return bound
|
||||
|
||||
def split_load(ctx:LocalAddBufferContext, s:UOp):
  """Register the BUFFER read by LOAD `s` in `ctx.map` and cut the load's ASSIGN dependency.

  The buffer is `s.src[0].src[0]`; loads of anything else return None. When the load has exactly
  two sources and the second is an ASSIGN, that ASSIGN is what gets recorded for the buffer and
  the returned LOAD keeps only its first source. Otherwise the buffer itself is recorded and
  None is returned.
  """
  b = s.src[0].src[0]
  if b.op is not Ops.BUFFER: return None

  depends_on_assign = len(s.src) == 2 and s.src[1].op is Ops.ASSIGN
  if depends_on_assign:
    lb = s.src[1]
    # a buffer must not be recorded with two different producers
    assert b not in ctx.map or ctx.map[b][0] == lb
  else:
    lb = b
  if b not in ctx.map:
    ctx.map[b] = (lb, ctx.dg)
    ctx.dg += 1
  if b is lb: return None
  return s.replace(src=s.src[0:1])
|
||||
|
||||
def handle_store(ctx:LocalAddBufferContext, s:UOp):
  """Register the BUFFER written by STORE `s` in `ctx.map`; a stored COPY replaces the STORE.

  Stores whose target (`s.src[0].src[0]`) is not a BUFFER return None. Otherwise the buffer gets
  the next DEFINE_GLOBAL number if it has none yet, and the result is `s.src[1]` when that is a
  COPY, else None.
  """
  b = s.src[0].src[0]
  if b.op is not Ops.BUFFER: return None
  if b not in ctx.map:
    ctx.map[b] = (b, ctx.dg)
    ctx.dg += 1
  stored = s.src[1]
  return stored if stored.op is Ops.COPY else None
|
||||
|
||||
# run by split_store (bottom_up=True) with a LocalAddBufferContext as ctx
to_define_global = PatternMatcher([
  # BUFFER -> DEFINE_GLOBAL, numbered from ctx.map
  (UPat(Ops.BUFFER, name="b"), debuf),
  # BIND -> its first source, remembered in ctx.vars
  (UPat(Ops.BIND, name="b"), unbind_kernel),
  # LOAD / STORE register the buffers they touch
  (UPat(Ops.LOAD, name="s"), split_load),
  (UPat(Ops.STORE, name="s"), handle_store),
])
|
||||
|
||||
def split_store(x:UOp):
  """Turn a STORE with no open ranges into `buffer.assign(KERNEL)`.

  The STORE is rewritten with to_define_global. The kernel name is "k" followed by one
  `_<size>` per RANGE (sorted by arg), colored WHITE when the range is one of the store's own
  range sources (`x.src[2:]`) and red otherwise. The KERNEL sources are the recorded buffers
  followed by the recorded variables; the first source is what gets assigned to.
  Returns None while `x.ranges` is non-empty.
  """
  if len(x.ranges): return None
  store_rngs = x.src[2:]

  ctx = LocalAddBufferContext()
  ret = graph_rewrite(x, to_define_global, ctx=ctx, name="kernel split", bottom_up=True)

  rng = sorted([u for u in ret.toposort() if u.op is Ops.RANGE], key=lambda u: u.arg)
  parts = [colored(str(s.vmax+1), "WHITE") if s in store_rngs else colored(str(s.vmax+1), "red") for s in rng]
  # the leading '' makes the join emit a separator before the first size too
  name = "k"+colored('_', 'BLACK').join(['']+parts)

  # only a result that is still a STORE gets wrapped in a named SINK (handle_store may return a COPY)
  if ret.op is Ops.STORE: ret = ret.sink(arg=KernelInfo(name=name))
  kernel_srcs = tuple([entry[0] for entry in ctx.map.values()])+tuple(ctx.vars.keys())
  kernel = UOp(Ops.KERNEL, src=kernel_srcs, arg=Kernel(ret, ()))
  return kernel.src[0].assign(kernel)
|
||||
|
||||
# every STORE is offered to split_store, which declines (returns None) while ranges are still open
split_kernels = PatternMatcher([
  (UPat(Ops.STORE, name="x"), split_store),
])
|
||||
|
||||
pm_children_fixup = PatternMatcher([
  # clone all movement ops: CHILD(CHILDREN(movement(src))) becomes movement(CHILD(CHILDREN(src))),
  # so each child carries its own copy of the movement op
  (UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat(GroupOp.Movement, name="m"),)),), name="c"),
   lambda c,m: UOp(m.op, m.dtype, (c.replace(src=(c.src[0].replace(src=(m.src[0],)),)),), m.arg)),
])
|
||||
|
||||
|
||||
@dataclass
class RContext:
  """Counter state for the td_* rewrites: the number given to the next RANGE made by new_range."""
  # annotated so that @dataclass registers it as a real field: it can be passed to the
  # constructor and shows up in repr/eq (an unannotated name is only a plain class attribute)
  range_num: int = 0
|
||||
|
||||
def new_range(ctx, s):
  """Return `UOp.range(dtypes.int, s, ctx.range_num)` and increment `ctx.range_num` afterwards."""
  fresh = UOp.range(dtypes.int, s, ctx.range_num)
  ctx.range_num += 1
  return fresh
|
||||
|
||||
def td_reshape(ctx, idx:UOp, r:UOp):
  """Top-down RESHAPE: re-express the indices of `idx` for the reshaped shape `r.arg`.

  The indices are linearized against the current shape (`idx.arg`), split per axis of `r.arg`
  with mod / floor-div (size-1 axes get constant 0), and the indexed source is re-indexed with
  them; the new INDEX carries `r.arg` as its arg. `ctx` is not used.
  """
  # linearize: last axis has stride 1
  stride = 1
  terms = []
  for dim, ax in reversed(list(zip(idx.arg, idx.src[1:]))):
    terms.append(ax*stride)
    stride *= dim
  flat = sum(terms)

  per_axis = []
  for dim in reversed(r.arg):
    if resolve(dim!=1):
      # this MOD should limit any ranges outside dim
      per_axis.append(flat % dim)
      flat //= dim
    else:
      per_axis.append(UOp.const(dtypes.int, 0))
  per_axis = UOp.sink(*per_axis).simplify().src[::-1] if per_axis else ()

  # NOTE: a disabled experiment used to sit here as an inert string literal: it gave every index
  # that was neither a RANGE nor a CONST a fresh range and wrapped the source in Ops.MERGE
  return idx.src[0].index(*per_axis, dtype=idx.dtype, arg=r.arg)
|
||||
|
||||
def td_elementwise(ctx, e:UOp):
  """Top-down elementwise: hoist the INDEX of every source above `e`, unifying their ranges.

  Per axis, constant-0 indices are ignored; if the remaining indices of all sources are the same
  and not a range source of a REDUCE below `e`, that index is reused. Otherwise a fresh range is
  made and every source whose index differs is wrapped in a MERGE of
  (source, old index, new index, ...) pairs.
  """
  # if the range is closed by a reduce to the left, we can't reuse it
  # TODO: handle composite ranges better
  reduces_left = flatten([u.src[1:] for u in e.toposort() if u.op is Ops.REDUCE])
  shps = [u.arg for u in e.src]
  assert all_same(shps)
  zero = UOp.const(dtypes.int, 0)

  out_rng = []
  need_merge = False
  for i, column in enumerate(zip(*[u.src[1:] for u in e.src])):
    live = [a for a in column if a is not zero]
    if len(live) == 0:
      out_rng.append(zero)
    elif all_same(live) and live[0] not in reduces_left:
      out_rng.append(live[0])
    else:
      out_rng.append(new_range(ctx, shps[0][i]))
      need_merge = True

  if not need_merge:
    new_src = [u.src[0] for u in e.src]
  else:
    new_src = []
    for u in e.src:
      assert u.op is Ops.INDEX
      out = [u.src[0]]
      rngs_in_src = [n for n in out[0].toposort() if n.op is Ops.RANGE]
      # walk the axes from last to first
      for i, ax in reversed(list(enumerate(u.src[1:]))):
        rngs_in_idx = [n for n in ax.toposort() if n.op is Ops.RANGE]
        # for expands: none of this index's ranges occur in the source, nothing to merge
        if all(n not in rngs_in_src for n in rngs_in_idx): continue
        if ax is not out_rng[i]:
          out.append(ax)
          out.append(out_rng[i])
      new_src.append(UOp(Ops.MERGE, u.dtype, tuple(out)) if len(out) > 1 else out[0])
  return e.replace(src=tuple(new_src)).index(*out_rng, arg=shps[0])
|
||||
|
||||
def td_shrink(idx:UOp, r:UOp):
  """Top-down SHRINK: keep the indices and cap each axis size at the shrink end.

  Only shrinks starting at 0 are supported (asserted per axis). The new INDEX arg is
  `min(current size, end)` for every axis.
  """
  kept = []
  new_shape = []
  for u,(s,e),size in zip(idx.src[1:], r.arg, idx.arg):
    assert s == 0
    #if u.vmax >= e: u = (u<e).where(u, UOp(Ops.INVALID, u.dtype))
    kept.append(u)
    new_shape.append(min(size, e))
  return idx.src[0].index(*kept, dtype=idx.dtype, arg=tuple(new_shape))
|
||||
|
||||
def td_reduce(ctx, idx:UOp, r:UOp):
  """Top-down REDUCE_AXIS: build a REDUCE over the indices of the reduced axes.

  The REDUCE takes the indexed source plus the indices at the axes in `r.arg[1]`, with the op
  `r.arg[0]` as arg. It is then indexed with constant 0 on those axes (size 1 in the new shape)
  and the untouched indices elsewhere. `ctx` is not used.
  """
  rngs = idx.src[1:]
  axes = r.arg[1]
  new_shp = tuple([1 if i in axes else s for i,s in enumerate(idx.arg)])
  reduced = tuple([x for i,x in enumerate(rngs) if i in axes])
  outer = [UOp.const(dtypes.int, 0) if i in axes else x for i,x in enumerate(rngs)]
  return UOp(Ops.REDUCE, r.dtype, (idx.src[0],)+reduced, r.arg[0]).index(*outer, arg=new_shp)
|
||||
|
||||
# top-down variant: every node ends up wrapped in an INDEX whose arg is the node's shape
pm_td_rangeify = PatternMatcher([
  #(UPat(Ops.INDEX, src=(UPat(Ops.MERGE, src=(UPat(Ops.LOAD, name="b"),), allow_any_len=True),), allow_any_len=True, name="idx"),
  # lambda idx,b: b.src[0].src[0].index(*idx.src[1:], dtype=b.src[0].dtype).load().index(*idx.src[1:], arg=idx.arg)),
  # an untagged BUFFER is tagged, loaded through one fresh range over b.size, and indexed with that range
  (UPat(Ops.BUFFER, name="b"), lambda ctx, b:
   b.replace(tag=1).index(nr:=new_range(ctx, b.size), dtype=b.dtype.ptr(size=b.size)).load().index(nr, arg=(b.size,)) if b.tag is None else None),
   #b.replace(tag=1).index(new_range(ctx, b.size), arg=(b.size,)) if b.tag is None else None),
  # a CONST on a DEVICE loses its source and gets an INDEX with the empty shape
  (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="c"), lambda c: c.replace(src=()).index(arg=())),
  (UPat(Ops.RESHAPE, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_reshape),
  (UPat(Ops.SHRINK, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_shrink),
  # PERMUTE reorders both the indices and the shape by r.arg
  (UPat(Ops.PERMUTE, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
   lambda r,idx: idx.src[0].index(*[idx.src[1+p] for p in r.arg], dtype=idx.dtype, arg=tuple(idx.arg[p] for p in r.arg))),
  # 0s are already in place for EXPAND
  #(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"), lambda r,idx: idx.replace(arg=r.arg)),
  # EXPAND: axes whose size changes get a fresh range, the others keep their index
  (UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
   lambda ctx,r,idx: idx.src[0].index(*[ii if s1==s2 else new_range(ctx, s1) for s1,s2,ii in zip(r.arg, idx.arg, idx.src[1:])], arg=r.arg)),
  (UPat(GroupOp.Elementwise, src=UPat(Ops.INDEX), name="e"), td_elementwise),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_reduce),
])
|
||||
|
||||
def remove_merge(m):
  """Lower a MERGE node into a BUFFERIZE followed by an INDEX.

  The sources after the first are read as interleaved pairs `(r0, r1)`: `m.src[1::2]` paired with
  `m.src[2::2]`. Pairs whose two members are the same object are dropped. The remaining `r0`s
  become the extra sources of the BUFFERIZE and the remaining `r1`s index its result.

  The previous "hack for LOAD" branch was guarded by `... and False`, so it could never run; it
  has been removed.
  """
  pairs = [(r0, r1) for r0, r1 in zip(m.src[1::2], m.src[2::2]) if r0 is not r1]
  buf_rngs = tuple(r0 for r0, _ in pairs)
  idx_rngs = [r1 for _, r1 in pairs]
  return UOp(Ops.BUFFERIZE, m.dtype, (m.src[0],)+buf_rngs, arg=m.device).index(*idx_rngs)
|
||||
|
||||
# single-rule matcher: every MERGE node is handed to remove_merge
no_merge = PatternMatcher([(UPat(Ops.MERGE, name="m"), remove_merge)])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
  """Run the scheduling rewrites on `sink`, then the experimental "td rangeify" lowering.

  In its current state the function always takes the render branch (`getenv("SRC") or True`):
  it lowers the graph, renders it with CStyleLanguage, prints the source and returns `{sink: sink}`.
  The kernel-splitting / assign-fixup code after that branch is kept but is unreachable.
  """
  tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
  realize_map = {}
  graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
  tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add contiguous")
  tensor_map = graph_rewrite_map(tensor_map[sink], early_cleanups+remove_tags, input_map=tensor_map, name="cleanup")
  rsink = tensor_map[sink]

  ctx = RContext()
  rsink = graph_rewrite(rsink, pm_td_rangeify, ctx=ctx, name="td rangeify")
  rsink = graph_rewrite(rsink, sym, name="symbolic")

  # find MOD on RANGE to split
  # repeat until a pass finds nothing to split; each pass substitutes and re-simplifies
  while True:
    #break
    reps = {}
    for u in rsink.toposort():
      if u.op is Ops.MOD and u.src[0].op is Ops.RANGE and u.src[1].op is Ops.CONST:
        r = u.src[0].vmax+1
        c = u.src[1].arg
        # only split when the constant divides vmax+1 evenly
        if r%c == 0:
          reps[u.src[0]] = new_range(ctx, r//c)*c + new_range(ctx, c)
    print(len(reps))
    if not reps: break
    rsink = rsink.substitute(reps)
    rsink = graph_rewrite(rsink, sym, name="symbolic")

  # NOTE: range(0) means this range-unification pass is currently disabled
  for i in range(0):
    print("loop")
    real_rngs = rsink.ranges.copy()
    for u in rsink.toposort():
      if u.op is Ops.REDUCE:
        for s in u.src[1:]: real_rngs[s] = None
    real_rngs = {x:[] for x in real_rngs}
    print("unmovable", [x.arg for x in real_rngs])

    # for every two-source MERGE, record the other side against each range in real_rngs
    for u in rsink.toposort():
      if u.op is not Ops.MERGE: continue
      assert all(x.op is Ops.RANGE for x in u.src)
      r0, r1 = u.src
      if r0 is r1: continue
      if r0 in real_rngs: real_rngs[r0].append(r1)
      if r1 in real_rngs: real_rngs[r1].append(r0)

    # replace every recorded range with the range it was recorded against
    rew = {}
    for k,v in real_rngs.items():
      print(k.arg, [x.arg for x in v])
      for u in v:
        rew[u] = k
    rsink = rsink.substitute(rew)

  """
  rngs = [x for x in rsink.toposort() if x.op is Ops.RANGE]
  mmap = {16:1000, 2:8, 3:9}
  rep = {}
  for x in rngs:
  if x.arg in mmap:
  rep[x] = x.replace(arg=mmap[x.arg])
  rsink = rsink.substitute(rep)
  """

  rsink = graph_rewrite(rsink, no_merge, name="remove merge")

  """
  tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="children")
  tensor_map = graph_rewrite_map(tensor_map[sink], pm_children_fixup, bottom_up=True, input_map=tensor_map, name="fixup children")
  tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="rangeify")
  """
  #tensor_map = graph_rewrite_map(tensor_map[sink], symbolic_simple, input_map=tensor_map, name="symbolic")
  #tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, bottom_up=True, input_map=tensor_map, name="add buffers")
  #if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Rangeify Graph")
  if getenv("VIZ"): graph_rewrite(rsink, PatternMatcher([]), name="View Rangeify Graph")

  rsink = graph_rewrite(rsink, pm_add_buffers, ctx=[0], bottom_up=True, name="add buffers")

  # render
  # NOTE: `or True` makes this branch unconditional, so the function always returns inside it
  if getenv("SRC") or True:
    #rsink = tensor_map[sink]
    from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
    rsink = graph_rewrite(rsink, pm_reduce, ctx=ReduceContext(), name="remove reduce")
    rsink = graph_rewrite(rsink, pm_debuf, ctx=[0], name="debuf", bottom_up=True)
    rsink = graph_rewrite(rsink, sym, name="symbolic 2")

    # renumber ranges
    #rngs = dedup([x for x in flatten([x.src[2:] for x in list(rsink.toposort())[::-1] if x.op is Ops.STORE]) if x.op is Ops.RANGE])
    #rsink = rsink.substitute({x:x.replace(arg=i) for i,x in enumerate(rngs)})

    from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
    rsink = apply_rewrites(rsink, rewrites_for_linearizer)
    from tinygrad.renderer.cstyle import CStyleLanguage
    src = CStyleLanguage().render(rsink.arg.lst)
    print(src)
    return {sink:sink}

  # NOTE: everything below is unreachable while the branch above is unconditional
  tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="split kernels")

  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in tensor_map[sink].toposort():
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
      # TODO: this is probably broken for MSELECT/MSTACK
      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
  if assign_rep:
    tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")

  if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
  return tensor_map
|
||||
|
|
@ -14,7 +14,7 @@ from tinygrad.device import Device, Buffer
|
|||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
from tinygrad.schedule.kernelize import get_kernelize_map
|
||||
from tinygrad.schedule.rangeify import get_kernelize_map
|
||||
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
|
|
@ -252,7 +252,8 @@ class Tensor(MathTrait):
|
|||
# create the schedule
|
||||
schedule, var_vals = create_schedule_with_vars(sink)
|
||||
schedule = memory_planner(schedule)
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
|
||||
if (DEBUG >= 1 and len(schedule) >= 10) or (DEBUG >= 2 and len(schedule) > 1):
|
||||
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
|
||||
return schedule, var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
||||
|
|
|
|||
|
|
@ -12,13 +12,15 @@ class Ops(FastEnum):
|
|||
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
|
||||
|
||||
# track children
|
||||
CHILD = auto()
|
||||
CHILD = auto(); CHILDREN = auto() # noqa: E702
|
||||
MERGE = auto(); MBLOCK = auto(); INVALID = auto()
|
||||
|
||||
# buffer ops
|
||||
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
|
||||
BUFFERIZE = auto()
|
||||
|
||||
# blocks in linearizer (only used there)
|
||||
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702
|
||||
|
|
|
|||
|
|
@ -136,11 +136,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
@functools.cached_property
|
||||
def st(self) -> ShapeTracker|None:
|
||||
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
|
||||
if self.op is Ops.INDEX and self.src[0].op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER, Ops.BUFFERIZE}: return None
|
||||
if self.op is Ops.MBLOCK: return None
|
||||
if self.op in GroupOp.Block: return None
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
# VIEW and MovementOps define a new ShapeTracker from the arg
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape((prod([r.vmax+1 for r in self.src[1:]]),))
|
||||
if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
|
||||
if self.op in GroupOp.Movement:
|
||||
if self.src[0].st is None: return None
|
||||
return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
# CONST with a DEVICE has a shape of ()
|
||||
if self.op is Ops.CONST and len(self.src) and self.src[0].op is Ops.DEVICE: return ShapeTracker.from_shape(())
|
||||
# BufferOps and ASSIGN flow ShapeTracker from a direct edge
|
||||
|
|
@ -158,7 +164,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
|
||||
|
||||
# otherwise we get the shape from sources
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None and x.op is not Ops.INDEX]): return None
|
||||
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
|
||||
match self.op:
|
||||
case Ops.MULTI: shape = tuple(self.src[0].shape[a]*len(self.device) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
|
||||
|
|
@ -186,10 +192,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
@functools.cached_property
|
||||
def ranges(self) -> dict[UOp, None]:
|
||||
if self.op is Ops.RANGE: return {self:None}
|
||||
if self.op in {Ops.CONTIGUOUS, Ops.REDUCE, Ops.STORE}:
|
||||
if self.op is Ops.MERGE:
|
||||
ret = self.src[0].ranges.copy()
|
||||
for s in self.src[1::2]:
|
||||
if s in ret: del ret[s]
|
||||
for s in self.src[2::2]:
|
||||
ret.update(s.ranges)
|
||||
return ret
|
||||
if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
|
||||
ret = self.src[0].ranges.copy()
|
||||
for s in self.src[1:]:
|
||||
if s in ret: del ret[s]
|
||||
elif self.op in {Ops.STORE}:
|
||||
ret = self.src[0].ranges.copy()
|
||||
ret.update(self.src[1].ranges)
|
||||
for s in self.src[2:]:
|
||||
if s in ret: del ret[s]
|
||||
else:
|
||||
ret = {}
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
|
|
@ -293,6 +311,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
|
||||
def contiguous(self, *args, **kwargs): return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
|
||||
def bufferize(self, *args, **kwargs): return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
def fuse(self): return self.alu(Ops.FUSE)
|
||||
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
|
||||
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
|
||||
|
|
@ -671,6 +690,7 @@ class UPat(MathTrait):
|
|||
def reduce(self, *src:UPat, **kwargs): return UPat(Ops.REDUCE, self.dtype, src=(self,)+src, **kwargs)
|
||||
def fuse(self): return self.alu(Ops.FUSE)
|
||||
def or_broadcasted(self, **kwargs): return UPat.any(self, UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs))
|
||||
def contiguous(self, *args, **kwargs): return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=(self,)+args, **kwargs)
|
||||
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
|
|||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
|
||||
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
Ops.CHILD: "#80fff0"}
|
||||
Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C"}
|
||||
|
||||
# VIZ API
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
|
||||
for idx,x in enumerate(u.src):
|
||||
if x in excluded:
|
||||
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
|
||||
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}" + (f" {x.src[0].op}" if len(x.src) else "")
|
||||
else: label += f"\n{x.op.name}{idx} {x.arg}"
|
||||
try:
|
||||
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue