Compare commits

...

37 commits

Author SHA1 Message Date
George Hotz
c7ba7fcde1 ignore expand 2025-08-16 17:59:07 -07:00
George Hotz
86d7f7d224 junk 2025-08-16 17:41:18 -07:00
George Hotz
a5e6c7dbc9 do the store 2025-08-16 14:29:10 -07:00
George Hotz
707d8d9d72 prefetch 2025-08-16 14:00:14 -07:00
George Hotz
40767e8f92 rangeify td 2025-08-16 12:26:46 -07:00
George Hotz
41ab1e0852 clone movement ops 2025-08-16 09:31:14 -07:00
George Hotz
465eff25de work 2025-08-16 09:12:11 -07:00
George Hotz
57014d2302 testing backward 2025-08-16 08:52:13 -07:00
George Hotz
06fe3a2d57 no pcontig 2025-08-15 22:05:17 -07:00
George Hotz
778358eff6 fix bmnist 2025-08-15 19:12:14 -07:00
George Hotz
c4d565b591 improve names 2025-08-15 18:46:04 -07:00
George Hotz
6ef39f4657 mnist works 2025-08-15 18:37:59 -07:00
George Hotz
49cd68945c that seems to work 2025-08-15 18:35:23 -07:00
George Hotz
08ec4de6a3 beautiful mnist is close 2025-08-15 18:31:14 -07:00
George Hotz
7dc708e735 ops fixed 2025-08-15 18:05:54 -07:00
George Hotz
65cbf9d785 late children 2025-08-15 17:51:58 -07:00
George Hotz
b3f5852fae unbind_kernel 2025-08-15 17:19:36 -07:00
George Hotz
780a49ebfb symbolic in schedule 2025-08-15 16:43:13 -07:00
George Hotz
d1d4bbe179 contigs only 2025-08-15 16:18:28 -07:00
George Hotz
998775507d basic assign 2025-08-15 16:10:32 -07:00
George Hotz
14e055a8b6 progress counter 2025-08-15 15:52:26 -07:00
George Hotz
8418d06300 progress 2025-08-15 15:46:32 -07:00
George Hotz
4a6c2e5d68 new endrange solution 2025-08-15 15:44:31 -07:00
George Hotz
801712880e
Merge branch 'master' into lil_rangeify 2025-08-15 14:56:14 -07:00
George Hotz
20f8fe4443 progress children 2025-08-15 14:53:13 -07:00
George Hotz
6425ed93c8 more stuff passes 2025-08-15 14:37:53 -07:00
George Hotz
e48c87ba71 fix test_log_softmax 2025-08-15 14:13:47 -07:00
George Hotz
717b0d107c stuff passes 2025-08-15 12:37:19 -07:00
George Hotz
2b93b50710 fix rangeify tests 2025-08-15 12:08:52 -07:00
George Hotz
e4884845a3 flash attention is back 2025-08-15 11:18:45 -07:00
George Hotz
d040d55960
Merge branch 'master' into lil_rangeify 2025-08-15 11:08:50 -07:00
George Hotz
c707e87da0 fix rangeify 2025-08-15 10:19:50 -07:00
George Hotz
155e97045d ish 2025-08-15 10:11:28 -07:00
George Hotz
05f04bbcc3 work 2025-08-15 09:47:38 -07:00
George Hotz
551b34bb0f bufferize, don't use contig tag 2025-08-15 09:00:18 -07:00
George Hotz
0575c95389 bring that over 2025-08-15 08:54:33 -07:00
George Hotz
18130dae2a ** rangeify, try 3 2025-08-15 08:46:41 -07:00
10 changed files with 833 additions and 15 deletions

View file

@ -29,8 +29,7 @@ if __name__ == "__main__":
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss
return loss.realize(*opt.schedule_step())
@TinyJit
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

113
test/test_rangeify.py Normal file
View file

@ -0,0 +1,113 @@
import unittest
from tinygrad import Tensor
class TestRangeify(unittest.TestCase):
  """Smoke tests: every case builds a graph from `Tensor.empty` inputs and realizes it.

  None of the cases assert on values; each one passes if `realize()` completes.
  """

  def test_add(self):
    # elementwise add of two square matrices
    N = 1024
    A, B = Tensor.empty(N, N), Tensor.empty(N, N)
    (A+B).realize()

  def test_double_gemm(self):
    # two chained matmuls
    N = 1024
    A, B, C = Tensor.empty(N, N), Tensor.empty(N, N), Tensor.empty(N, N)
    (A@B@C).realize()

  def test_double_gemm_exp(self):
    # exp after each matmul
    N = 1024
    A, B, C = Tensor.empty(N, N), Tensor.empty(N, N), Tensor.empty(N, N)
    first = (A@B).exp()
    ((first@C).exp()).realize()

  def test_double_gemm_relu(self):
    # relu after each matmul
    N = 1024
    A, B, C = Tensor.empty(N, N), Tensor.empty(N, N), Tensor.empty(N, N)
    first = (A@B).relu()
    ((first@C).relu()).realize()

  def test_double_gemm_relu_half_contig(self):
    # same as above, with contiguous(arg=(1,)) on the intermediate
    N = 1024
    A, B, C = Tensor.empty(N, N), Tensor.empty(N, N), Tensor.empty(N, N)
    first = (A@B).relu().contiguous(arg=(1,))
    ((first@C).relu()).realize()

  def test_double_gemm_half_contig(self):
    # contiguous(arg=(1,)) between the two matmuls
    N = 1024
    A, B, C = Tensor.empty(N, N), Tensor.empty(N, N), Tensor.empty(N, N)
    first = (A@B).contiguous(arg=(1,))
    (first@C).realize()

  def test_double_gemm_contig(self):
    # plain contiguous() between the two matmuls
    N = 1024
    A, B, C = Tensor.empty(N, N), Tensor.empty(N, N), Tensor.empty(N, N)
    first = (A@B).contiguous()
    (first@C).realize()

  def test_many_gemm(self):
    # five chained matmuls over six inputs
    N = 1024
    A = Tensor.empty(N, N)
    B = Tensor.empty(N, N)
    C = Tensor.empty(N, N)
    D = Tensor.empty(N, N)
    E = Tensor.empty(N, N)
    F = Tensor.empty(N, N)
    (A@B@C@D@E@F).realize()

  def test_conv2d(self):
    # single 3x3 convolution
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    x.conv2d(w1).realize()

  def test_conv2d_t(self):
    # convolution on a scaled input
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    scaled = x*2
    scaled.conv2d(w1).realize()

  def test_double_conv2d(self):
    # two stacked convolutions
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    w2 = Tensor.empty(12, 8, 3, 3)
    x.conv2d(w1).conv2d(w2).realize()

  def test_double_conv2d_half_contig(self):
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    w2 = Tensor.empty(12, 8, 3, 3)
    # NOTE: this contiguous doesn't help
    first = x.conv2d(w1).contiguous(arg=(1,))
    first.conv2d(w2).permute(0,2,3,1).contiguous().realize()

  def test_double_conv2d_contig(self):
    # contiguous() between the two convolutions
    x = Tensor.empty(1, 4, 32, 32)
    w1 = Tensor.empty(8, 4, 3, 3)
    w2 = Tensor.empty(12, 8, 3, 3)
    x.conv2d(w1).contiguous().conv2d(w2).realize()

  def test_ffn(self):
    # feed-forward part of a TransformerBlock, with every parameter swapped for an empty tensor
    from tinygrad.apps.llm import TransformerBlock
    from tinygrad import nn
    blk = TransformerBlock(1024, 4096, 1, 1, 1e-5)
    for param in nn.state.get_parameters(blk):
      param.replace(Tensor.empty(param.shape))
    x = Tensor.empty(128, 1024)
    out = blk._feed_forward(x)
    out.realize()

  def test_flash_attention(self):
    # small scaled-dot-product attention
    BS = 4
    HEADS = 2
    MATDIM = 16
    EMB = 8
    q = Tensor.empty(BS, HEADS, MATDIM, EMB)
    k = Tensor.empty(BS, HEADS, MATDIM, EMB)
    v = Tensor.empty(BS, HEADS, MATDIM, EMB)
    q.scaled_dot_product_attention(k, v).realize()

if __name__ == '__main__':
  unittest.main()

View file

@ -1044,7 +1044,7 @@ class TestSchedule(unittest.TestCase):
k = Tensor.randn(32,8,16,8).realize()
v = Tensor.randn(32,8,16,8).realize()
out = Tensor.scaled_dot_product_attention(q,k,v)
run_schedule(check_schedule(out, 5))
#run_schedule(check_schedule(out, 5))
if getenv("CHECK", 1):
import torch
compare = torch.nn.functional.scaled_dot_product_attention(torch.tensor(q.numpy()),torch.tensor(k.numpy()),torch.tensor(v.numpy()))

View file

@ -100,12 +100,30 @@ class TestTiny(unittest.TestCase):
lambda x: x.flatten(1), nn.Linear(576, 10)]
# replace random weights with ones
Tensor.realize(*[p.replace(Tensor.ones_like(p).contiguous()) for p in nn.state.get_parameters(layers)])
for p in nn.state.get_parameters(layers): p.replace(Tensor.empty(p.shape))
# run model inference
probs = Tensor.rand(1, 1, 28, 28).sequential(layers).tolist()
probs = Tensor.empty(1, 1, 28, 28).sequential(layers).tolist()
self.assertEqual(len(probs[0]), 10)
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
@unittest.skipIf(IMAGE>0 or (CI and Device.DEFAULT == "DSP"), "failing because of make things that can't be images not images")
def test_mnist_backward(self):
# NOTE: we don't have the whole model here for speed
layers = [
nn.Conv2d(1, 32, 5), Tensor.relu,
nn.Conv2d(32, 32, 5), Tensor.relu]
# replace random weights with ones
# TODO: there's a bug here where it's tying two of the biases together. we need UNIQUE const
for p in nn.state.get_parameters(layers): p.replace(Tensor.empty(p.shape))
#for p in nn.state.get_parameters(layers): p.replace(Tensor.ones_like(p).contiguous().realize())
# realize gradients
for x in nn.state.get_parameters(layers): x.requires_grad_()
Tensor.empty(4, 1, 28, 28).sequential(layers).sum().backward()
Tensor.realize(*[x.grad for x in nn.state.get_parameters(layers) if x.grad is not None])
# *** image ***
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")

View file

@ -104,7 +104,9 @@ class CStyleLanguage(Renderer):
Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})",
Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})",
Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})",
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})"}
Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})", Ops.CMPEQ: lambda a,b,dtype: f"({a}=={b})",
# NOTE: these don't work, but they are nice for rendering
Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
string_rewrite = base_rewrite
extra_matcher = extra_pm

View file

@ -0,0 +1,663 @@
from typing import Any
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, AddrSpace, PtrDType
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, colored, flatten, dedup
from tinygrad.uop.symbolic import symbolic_simple, sym
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, KernelInfo, identity_element
# Rules reused by earliest_rewrites below. Each lambda returns a replacement UOp, or None when the rule does not apply.
imported_rewrites = PatternMatcher([
# UOp with size 0 is zero
# (applies to every op except SINK, and only when root.base.st is not None)
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
# (only when the reduce output itself is non-empty; reduce.arg[0] is passed to identity_element together with the dtype)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
])
# Rule set of the first ("earliest") pass in get_kernelize_map: imported_rewrites plus RESHAPE/COPY/CONST/ASSIGN rules.
earliest_rewrites = imported_rewrites+PatternMatcher([
# RESHAPE on RESHAPE is the second reshape
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))),
# non shape changing RESHAPE is NOOP
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
# RESHAPE after COPY
# (the COPY is re-sourced from the RESHAPE's input, then the same reshape is applied to the COPY)
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
# TODO: this should be BUFFER_VIEW
# (same trick as above for SHRINK: copy the unshrunk source, then shrink the COPY)
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
# const hacks
# (a CONST whose first source is a VIEW, with no zero-sized dim, is re-sourced from the VIEW's own source,
#  reshaped to an all-ones shape and expanded back to the original shape)
(UPat(Ops.CONST, name="x"), lambda x:
x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
# assign only to buffer
# NOTE(review): "x" names the second source (the assigned value), so the check below reads x.src[0].base, not the
# base of the ASSIGN target matched by the first (unnamed) UPat. Confirm this is intended.
(UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}), UPat(name="x"))), lambda x: x if x.src[0].base.op is not Ops.BUFFER else None),
])
# 1. add contiguous where we have to
# Ops that are skipped when marking UOps for realization (see realize_parents and the SINK rule in do_realize).
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
def realize(ctx:dict[UOp, None], tr:UOp) -> None:
  """Record `tr` in the realize set `ctx` (a dict used as an ordered set; values are always None)."""
  ctx[tr] = None
def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
  """Record every source of `rb` in `ctx`, except sources whose op is in ALWAYS_CONTIGUOUS."""
  ctx.update((parent, None) for parent in rb.src if parent.op not in ALWAYS_CONTIGUOUS)
# Collects the UOps to realize into ctx (a dict used as a set). The callbacks here return None
# (dict.update, realize and realize_parents all return None), so this pass only fills ctx.
do_realize = PatternMatcher([
# always realize SINK parents
# (the .base of each SINK source, unless its op is in ALWAYS_CONTIGUOUS)
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize ASSIGN/COPY/BUFFER_VIEW
(UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
# realize parents of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
])
# Wraps every UOp recorded in ctx in CONTIGUOUS. The UOp is tagged with 1 first and only untagged UOps match,
# so an already-wrapped UOp does not match again.
add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS}, name="x"),
lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
# Clears any non-None tag.
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
# CONTIGUOUS of CONTIGUOUS collapses to the inner CONTIGUOUS.
early_cleanups = PatternMatcher([(UPat().contiguous(name="c").contiguous(), lambda c: c),])
# 2. mark all children
@dataclass
class ChildrenContext:
  """Context object for the children-marking pass (pm_children)."""
  # UOp -> list of its non-SINK consumers; None until extract_children fills it in
  children: dict[UOp, list[UOp]]|None = None
def extract_children(ctx:ChildrenContext, x:UOp):
  """Fill `ctx.children` from the children map of `x`; later calls are no-ops.

  A UOp is kept only if it has at least two non-SINK consumers and a REDUCE_AXIS
  appears somewhere in its own toposort. Always returns None.
  """
  # already extracted on a previous call
  if ctx.children is not None: return
  children_map = x.get_children_map()
  ctx.children = {}
  for parent, consumers in children_map.items():
    real_consumers = [u for u in consumers if u.op is not Ops.SINK]
    # single-consumer (or unused) UOps are not tracked
    if len(real_consumers) <= 1: continue
    has_reduce = any(u.op is Ops.REDUCE_AXIS for u in parent.toposort())
    if has_reduce: ctx.children[parent] = real_consumers
def mark_children(ctx:ChildrenContext, x:UOp):
  """Return `x` with every tracked source wrapped as CHILD(CHILDREN(source)).

  CHILDREN carries the consumer count as its arg; CHILD carries
  (position of `x` among the consumers, consumer count).
  """
  wrapped = []
  for s in x.src:
    if s in ctx.children:
      consumers = ctx.children[s]
      gate = UOp(Ops.CHILDREN, s.dtype, (s,), arg=len(consumers))
      wrapped.append(UOp(Ops.CHILD, s.dtype, src=(gate,), arg=(consumers.index(x), len(consumers))))
    else:
      # untracked sources pass through unchanged
      wrapped.append(s)
  return x.replace(src=tuple(wrapped))
# Children-marking pass. In this view it is only referenced from a string-quoted (disabled) section of get_kernelize_map.
pm_children = PatternMatcher([
# on SINK, build ctx.children (extract_children does this once and returns None)
(UPat(Ops.SINK, name="x"), extract_children),
# every op except CHILD/CHILDREN gets its tracked sources wrapped in CHILD(CHILDREN(...))
(UPat(GroupOp.All-{Ops.CHILD, Ops.CHILDREN}, name="x"), mark_children),
# hack for one kernel threefry
#(UPat(Ops.CHILD, src=(UPat(Ops.THREEFRY, name="x"),)), lambda x: x),
])
# 3. rangeify
@dataclass
class RangeifyContext:
  """Mutable state threaded through the pm_rangeify pass."""
  # next id handed to UOp.range (bumped by map_contiguous, map_reduce and index_child)
  idx: int = 0
  # not read or written by any function visible in this file section
  regs: int = 0
  # CHILDREN UOp -> {child position: INDEX seen for that child} (filled by index_child)
  seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
  # CHILDREN UOp -> (axes that got a fresh range, the fresh ranges) (filled by index_child)
  seen_child: dict[UOp, Any] = field(default_factory=dict)
  # number of consecutive "not ready" waits in index_child; reset to 0 once a child resolves
  progress: int = 0
  # not touched by the rangeify functions visible here
  children: dict[UOp, list[UOp]]|None = None
def map_reshape(idx:UOp, r:UOp):
  """Push an INDEX through a RESHAPE.

  The indices of `idx` are folded into one linear offset over `idx.shape` (last axis
  fastest), then split back into per-axis indices over the shape of the RESHAPE's source.
  """
  # linearize: sum of index*stride, walking from the last axis to the first
  stride = 1
  terms = []
  for dim, axis_idx in reversed(list(zip(idx.shape, idx.src[1:]))):
    terms.append(stride*axis_idx)
    stride *= dim
  linear = sum(terms)
  # de-linearize over the source shape, again last axis first
  pieces = []
  for dim in r.src[0].shape[::-1]:
    if resolve(dim!=1):
      # this MOD should limit any ranges outside s
      pieces.append(linear % dim)
      linear //= dim
    else:
      # size-1 axes are always indexed at 0
      pieces.append(UOp.const(dtypes.int, 0))
  # simplify all pieces together through a temporary SINK, then restore first-axis-first order
  new_idx = UOp.sink(*pieces).simplify().src[::-1] if pieces else ()
  return r.src[0].index(*new_idx, dtype=idx.dtype, arg=idx.arg)
def map_pad(idx:UOp, r:UOp):
  """Push an INDEX through a PAD.

  Returns a WHERE that reads the PAD's source when every padded axis is inside the
  unpadded region and yields 0 otherwise.
  """
  coords = list(idx.src[1:])
  in_bounds = UOp.const(dtypes.bool, True)
  for axis, (size, (before, after)) in enumerate(zip(r.shape, r.arg)):
    # unpadded axes need neither a bounds check nor an offset
    if before == 0 and after == 0: continue
    axis_ok = UOp.const(dtypes.bool, True)
    if after > 0: axis_ok = axis_ok & (coords[axis] < (size-after))
    if before > 0: axis_ok = axis_ok & (coords[axis] >= before)
    in_bounds = in_bounds & axis_ok
    # this is safe but dumb
    coords[axis] = (coords[axis] - before).maximum(0).minimum(r.src[0].shape[axis]-1)
  gate = in_bounds.simplify()
  # PAD is with 0
  return gate.where(r.src[0].index(*coords, dtype=idx.dtype, arg=idx.arg), UOp.const(r.dtype, 0))
def map_expand(r:UOp, idx:UOp):
  """Push an INDEX through an EXPAND.

  Expanded axes are indexed at 0 in the source. The new INDEX arg is the smallest
  RANGE arg among ranges used only on expanded axes (together with the incoming
  `idx.arg`, if any), or None when there are none.
  """
  src_idx = []
  expanded_rngs = []
  kept_rngs = []
  for axis_idx, src_dim, out_dim in zip(idx.src[1:], r.src[0].shape, r.shape):
    rngs_here = [u for u in axis_idx.toposort() if u.op is Ops.RANGE]
    if resolve(src_dim!=out_dim, False):
      # this axis was broadcast: its ranges end here and the source is read at 0
      expanded_rngs.extend(rngs_here)
      src_idx.append(axis_idx.const_like(0))
    else:
      kept_rngs.extend(rngs_here)
      src_idx.append(axis_idx)
  # a range that is also used on a non-expanded axis does not end
  ending = [rng.arg for rng in expanded_rngs if rng not in kept_rngs]
  if idx.arg is not None: ending.append(idx.arg)
  return r.src[0].index(*src_idx, arg=min(ending) if ending else None)
# Rules that push an INDEX through a movement op by rewriting the indices in terms of the movement op's source.
# Reused as the base of pm_rangeify and pm_add_buffers.
pm_mops = PatternMatcher([
# this is like the definitions of these
# SHRINK: add each axis' start offset to its index (skipped when the offset resolves to 0)
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)),
# PERMUTE: reorder the indices by the argsort of the permutation (idx.src[0] is the PERMUTE itself)
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)),
# FLIP: mirror the index on every flipped axis
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)),
# expand needs to end ranges
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand),
# reshape does a lot of symbolic stuff
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape),
# pad adds min and max
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad),
])
def map_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp|None=None):
  """Turn a CONTIGUOUS into INDEX -> BUFFERIZE over freshly created ranges.

  With partial contig disabled (`arg` is forced to None) only the un-indexed form is
  handled: the call returns None whenever `idx` is given. The partial-contig code
  path is kept in place for when `arg = x.arg` is re-enabled.
  """
  # NOTE: partial contig is disabled for now
  #arg = x.arg
  arg = None
  # full contig is only handled without an INDEX on top, partial contig only with one
  if (arg is None) != (idx is None): return None
  ranges, new_ranges, passthrough_idx = [], [], []
  for i, s in enumerate(x.shape):
    if arg is not None and i not in arg:
      # axis not made contiguous: keep indexing it with the caller's index
      assert idx is not None, "partial contig requires index"
      ranges.append(idx.src[1+i])
      continue
    if idx is not None: passthrough_idx.append(idx.src[1+i])
    if not resolve(s!=1):
      # size-1 axes need no range
      ranges.append(UOp.const(dtypes.int, 0))
      continue
    rng = UOp.range(dtypes.int, s, ctx.idx)
    ctx.idx += 1
    ranges.append(rng)
    new_ranges.append(rng)
  buffered = x.src[0].index(*ranges).bufferize(*new_ranges, arg=x.device)
  if passthrough_idx: return buffered.index(*passthrough_idx)
  return buffered.reshape(x.shape)
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
  """Rewrite INDEX(REDUCE_AXIS) into a REDUCE over fresh ranges.

  Each reduced axis (listed in `red.arg[1]`) gets a new range; the other axes keep
  the indices of `idx`. The reduce op is taken from `red.arg[0]`.
  """
  src_idx = list(idx.src[1:])
  reduce_rngs = []
  for axis, dim in enumerate(red.src[0].shape):
    if axis not in red.arg[1]: continue
    src_idx[axis] = UOp.range(dtypes.int, dim, ctx.idx)
    ctx.idx += 1
    reduce_rngs.append(src_idx[axis])
  indexed_src = red.src[0].index(*src_idx)
  return UOp(Ops.REDUCE, red.dtype, src=(indexed_src,)+tuple(reduce_rngs), arg=red.arg[0])
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
  """Handle INDEX(CHILD(c)): wait for all sibling children, then index `c` with shared ranges.

  While fewer than `x.arg[1]` children have been recorded, this records the current
  one and raises RewriteNotReady. Once all are recorded, axes on which the children
  disagree get a fresh range and are bufferized; axes on which they agree are passed
  through.

  Raises:
    RewriteNotReady: not every child has been seen yet.
    RuntimeError: more than 10000 consecutive waits happened without any child resolving.
  """
  seen = ctx.seen_children.setdefault(c, {})
  # wait here until we have seen all the children
  if len(seen) != x.arg[1]:
    ctx.progress += 1
    if ctx.progress > 10000: raise RuntimeError("children not making progress")
    # NOTE: we mark this here
    seen[x.arg[0]] = idx
    raise RewriteNotReady
  ctx.progress = 0
  if c in ctx.seen_child:
    # a sibling already decided which axes need a new range: reuse that decision
    out_rngs = list(idx.src[1:])
    idx_ranges, end_ranges = ctx.seen_child[c]
    for axis, new_rng in zip(idx_ranges, end_ranges): out_rngs[axis] = new_rng
  else:
    out_rngs, end_ranges, idx_ranges = [], [], []
    # group the children's indices per axis
    per_axis = zip(*[ch.src[1:] for ch in seen.values()])
    for axis, candidates in enumerate(per_axis):
      if all_same(candidates):
        out_rngs.append(candidates[0])
        continue
      # the children index this axis differently: materialize it over a fresh range
      fresh = UOp.range(dtypes.int, c.shape[axis], ctx.idx)
      ctx.idx += 1
      out_rngs.append(fresh)
      end_ranges.append(fresh)
      idx_ranges.append(axis)
    ctx.seen_child[c] = (idx_ranges, end_ranges)
  if len(idx_ranges) == 0: return c.index(*out_rngs)
  # NOTE: partial contigs can still come from here
  return c.index(*out_rngs).bufferize(*end_ranges, arg=x.device).index(*[idx.src[1+i] for i in idx_ranges])
def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
  """Drop the CHILDREN node under an INDEX once every child has been recorded.

  Raises:
    RuntimeError: the number of recorded children differs from `c.arg`.
  """
  if len(ctx.seen_children[c]) != c.arg:
    raise RuntimeError("all children should have been seen by now")
  # index the CHILDREN's own source directly
  unwrapped = idx.src[0].src[0]
  return idx.replace(src=(unwrapped,)+idx.src[1:])
def might_end_axis(idx:UOp):
  """Consume the INDEX arg: make contiguous the axes whose index uses a range with arg > `idx.arg`.

  Returns None when `idx.arg` is None. Otherwise the INDEX arg is cleared, and if any
  axis qualifies, the indexed UOp is wrapped in contiguous(arg=<those axes>).
  """
  if idx.arg is None: return None
  to_end_axis = [i for i, a in enumerate(idx.src[1:])
                 if any(u.arg > idx.arg for u in a.toposort() if u.op is Ops.RANGE)]
  if not to_end_axis: return idx.replace(arg=None)
  ended_src = idx.src[0].contiguous(arg=tuple(to_end_axis))
  return idx.replace(src=(ended_src,)+idx.src[1:], arg=None)
# Main rangeify rule set: the movement-op rules from pm_mops plus CONTIGUOUS/CHILD/elementwise/reduce handling.
# In this view it is only referenced from a string-quoted (disabled) section of get_kernelize_map.
pm_rangeify = pm_mops+PatternMatcher([
# sink contigs to kick it off
(UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"), map_contiguous),
# if there are new ended children, tag the SINK
# NOTE(review): index_child does not tag anything; it records the child and raises RewriteNotReady until all are seen
(UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
(UPat(Ops.INDEX, src=(UPat(Ops.CHILDREN, name="c"),), allow_any_len=True, name="idx"), children_gate),
# if we come across this, remove it. it was a CHILD unused in an INDEX
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
# if there's an INDEX it can support partial contig
# (map_contiguous currently returns None whenever idx is given, because partial contig is disabled there)
(UPat(Ops.INDEX, src=(UPat(Ops.CONTIGUOUS, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_contiguous),
# CONST can't have axes. remove srcs when we idx
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),
# handle arg on any op with weight. old endrange stuff
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
# move MAP through elementwise ALU / reduce. these are the items with cost
# (the INDEX is removed from the op and applied to each of the op's sources instead)
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
])
# 4. remove bufferize
def bufferize_to_store(ctx, x:UOp):
  """Lower a BUFFERIZE into a STORE.

  `ctx` is a one-element list holding the next DEFINE_LOCAL number. If the bufferized
  value is an ASSIGN, the store goes to the ASSIGN's target; otherwise a new
  DEFINE_LOCAL sized from the ranges is created and the counter is bumped.
  """
  rngs = x.src[1:]
  shape = tuple([r.vmax+1 for r in rngs])
  numel = prod(shape)
  assert numel > 0, f"no zero sized buffers {shape}"
  # every RANGE reachable from the bufferize indices closes at the store
  store_rngs = [u for u in UOp.sink(*rngs).toposort() if u.op is Ops.RANGE]
  value = x.src[0]
  if value.op is Ops.ASSIGN:
    target, assigned = value.src[0], value.src[1]
    return target.replace(dtype=x.dtype.ptr(size=numel)).store(assigned, *store_rngs)
  #buf = UOp.new_buffer(x.arg, prod(shape), x.dtype)
  buf = UOp(Ops.DEFINE_LOCAL, x.dtype.ptr(size=numel), arg=ctx[0])
  ctx[0] += 1
  return buf.reshape(shape).index(*rngs, dtype=x.dtype.ptr(size=numel)).store(value, *store_rngs)
def add_load_on_buffer(idx:UOp, b:UOp):
  """Turn a value-typed INDEX on BUFFER `b` into a pointer-typed INDEX followed by a LOAD.

  Returns None when the INDEX is already pointer-typed. The INDEX arg is cleared.
  """
  already_pointer = isinstance(idx.dtype, PtrDType)
  if already_pointer: return None
  ptr_index = idx.replace(dtype=idx.dtype.ptr(b.size), arg=None)
  return ptr_index.load()
def add_load_on_store(x:UOp, st:UOp):
  """Turn a value-typed INDEX on a STORE into a LOAD from the stored-to buffer.

  The buffer (`st.src[0].src[0]`) is shrunk and reshaped to the extent of the index
  ranges, indexed, and loaded with `st` passed as an extra LOAD source. Returns None
  when the INDEX is already pointer-typed.
  """
  if isinstance(x.dtype, PtrDType): return None
  rngs = x.src[1:]
  shape = tuple([r.vmax+1 for r in rngs])
  buf = st.src[0].src[0]
  #assert b.op is Ops.BUFFER
  view = buf.shrink(((0,prod(shape)),)).reshape(shape)
  return view.index(*rngs, dtype=x.dtype.ptr(size=buf.size)).load(st)
def shp(shp, rng):
  """Fold per-axis indices `rng` into one linear offset over shape `shp` (last axis fastest).

  Returns 0 when both sequences are empty.
  """
  stride = 1
  terms = []
  # walk from the last axis to the first, growing the stride as we go
  for dim, index in reversed(list(zip(shp, rng))):
    terms.append(index*stride)
    stride *= dim
  return sum(terms)
# "add buffers" pass of get_kernelize_map (run bottom-up with ctx=[0]): movement-op rules plus BUFFERIZE -> STORE
# and the insertion of LOADs on indexed buffers/stores.
pm_add_buffers = pm_mops+PatternMatcher([
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
# INDEX with exactly one index on a BUFFER becomes a pointer INDEX + LOAD
(UPat(Ops.INDEX, src=(UPat(Ops.BUFFER, name="b"), UPat()), name="idx"), add_load_on_buffer),
# INDEX on a STORE reads back from the stored-to buffer
(UPat(Ops.INDEX, src=(UPat(Ops.STORE, name="st"),), allow_any_len=True, name="x"), add_load_on_store),
# HACK
#(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# CONTIGUOUS of an INDEX: allocate a new buffer of prod(idx.arg) elements, index it with the linearized
# ranges (see shp) and store the INDEX's sources into it; idx.arg is used as the shape here
(UPat(Ops.INDEX, name="idx").contiguous(),
lambda idx: UOp.new_buffer(idx.device, prod(idx.arg), idx.dtype).index(shp(idx.arg, idx.src[1:]),
dtype=idx.dtype.ptr(prod(idx.arg))).store(*idx.src))
])
# 5 (alt). create pointers
def debuf(ctx, b:UOp):
  """Replace BUFFER `b` with a DEFINE_GLOBAL numbered from the counter in `ctx[0]`, then bump the counter."""
  slot = ctx[0]
  ctx[0] += 1
  return UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=slot)
# "debuf" pass used on the render path of get_kernelize_map (ctx=[0]). The `debuf` referenced here is the
# list-counter variant defined just above; a later function of the same name does not affect this matcher.
pm_debuf = PatternMatcher([
(UPat(Ops.BUFFER, name="b"), debuf),
# HACK: consts shouldn't have srcs by here
(UPat(Ops.CONST, name="x"), lambda x: x.replace(src=()) if len(x.src) else None),
# no movement ops
# (any remaining movement op is replaced by its source)
(UPat(GroupOp.Movement, name="x"), lambda x: x.src[0]),
# HACK: no copy
(UPat(Ops.COPY, name="x"), lambda x: x.src[0]),
])
# 5. split into kernels
@dataclass
class LocalAddBufferContext:
  """Per-kernel state collected while splitting a STORE into a kernel."""
  # next DEFINE_GLOBAL number to hand out
  dg: int = 0
  # buffer -> (kernel source for that buffer, DEFINE_GLOBAL number)
  map: dict = field(default_factory=dict)
  # BIND UOps encountered in the kernel, used as an ordered set (values are None)
  vars: dict = field(default_factory=dict)
def debuf(ctx:LocalAddBufferContext, b:UOp):
  """Replace BUFFER `b` with a DEFINE_GLOBAL numbered by the slot already recorded in `ctx.map[b]`.

  NOTE: this rebinds the module-level name `debuf`; the earlier list-counter variant
  stays in use by pm_debuf, which captured it when it was constructed.
  """
  slot = ctx.map[b][1]
  return UOp(Ops.DEFINE_GLOBAL, b.dtype.ptr(b.arg), arg=slot)
def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
  """Record BIND `b` in `ctx.vars` and replace it with its first source."""
  ctx.vars.update({b: None})
  unbound = b.src[0]
  return unbound
def split_load(ctx:LocalAddBufferContext, s:UOp):
  """Register the buffer read by LOAD `s` in `ctx.map` and cut the LOAD's link to an ASSIGN.

  Only LOADs whose index source is a BUFFER are handled. When the LOAD has exactly two
  sources and the second is an ASSIGN, that ASSIGN becomes the kernel source recorded
  for the buffer and the LOAD is returned with only its first source. Otherwise the
  buffer itself is recorded and None is returned.
  """
  b = s.src[0].src[0]
  if b.op is not Ops.BUFFER: return None
  depends_on_assign = len(s.src) == 2 and s.src[1].op is Ops.ASSIGN
  if depends_on_assign:
    lb = s.src[1]
    # a buffer may only ever be tied to one ASSIGN within a kernel
    assert b not in ctx.map or ctx.map[b][0] == lb
  else:
    lb = b
  if b not in ctx.map:
    ctx.map[b] = (lb, ctx.dg)
    ctx.dg += 1
  if b is lb: return None
  return s.replace(src=s.src[0:1])
def handle_store(ctx:LocalAddBufferContext, s:UOp):
  """Register the buffer written by STORE `s`; a STORE of a COPY is replaced by the COPY itself.

  Returns None for stores that do not target a BUFFER and for stores of anything
  other than a COPY.
  """
  b = s.src[0].src[0]
  if b.op is not Ops.BUFFER: return None
  if b not in ctx.map:
    ctx.map[b] = (b, ctx.dg)
    ctx.dg += 1
  stored = s.src[1]
  return stored if stored.op is Ops.COPY else None
# Rules applied bottom-up inside split_store with a LocalAddBufferContext. The `debuf` used here is the
# LocalAddBufferContext variant, which looks the DEFINE_GLOBAL number up in ctx.map.
to_define_global = PatternMatcher([
(UPat(Ops.BUFFER, name="b"), debuf),
# BINDs are recorded in ctx.vars and replaced by their first source
(UPat(Ops.BIND, name="b"), unbind_kernel),
# LOADs/STOREs register their buffer in ctx.map
(UPat(Ops.LOAD, name="s"), split_load),
(UPat(Ops.STORE, name="s"), handle_store),
])
def split_store(x:UOp):
  """Turn a STORE with no open ranges into an ASSIGN of a KERNEL to the kernel's first source.

  The kernel name is "k" followed by one "_<extent>" piece per RANGE in the rewritten
  body (sorted by range arg), colored "WHITE" when the range is one of the store's own
  ranges and "red" otherwise. Returns None while `x.ranges` is non-empty.
  """
  if len(x.ranges) != 0: return None
  store_rngs = x.src[2:]
  ctx = LocalAddBufferContext()
  body = graph_rewrite(x, to_define_global, ctx=ctx, name="kernel split", bottom_up=True)
  ordered = sorted((u for u in body.toposort() if u.op is Ops.RANGE), key=lambda u: u.arg)
  pieces = [colored(str(u.vmax+1), "WHITE" if u in store_rngs else "red") for u in ordered]
  # the leading '' makes the join emit a separator before the first extent
  name = "k"+colored('_', 'BLACK').join(['']+pieces)
  if body.op is Ops.STORE: body = body.sink(arg=KernelInfo(name=name))
  # kernel sources: one entry per registered buffer, followed by the unbound variables
  kernel_srcs = tuple([entry[0] for entry in ctx.map.values()])+tuple(ctx.vars.keys())
  kernel = UOp(Ops.KERNEL, src=kernel_srcs, arg=Kernel(body, ()))
  return kernel.src[0].assign(kernel)
# "split kernels" pass of get_kernelize_map: every STORE goes through split_store.
split_kernels = PatternMatcher([
(UPat(Ops.STORE, name="x"), split_store),
])
# "fixup children" rules; in this view only referenced from a string-quoted (disabled) section of get_kernelize_map.
pm_children_fixup = PatternMatcher([
# clone all movement ops
# (CHILD(CHILDREN(movement(src))) becomes movement(CHILD(CHILDREN(src))), keeping the movement op's dtype and arg)
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat(GroupOp.Movement, name="m"),)),), name="c"),
lambda c,m: UOp(m.op, m.dtype, (c.replace(src=(c.src[0].replace(src=(m.src[0],)),)),), m.arg)),
])
@dataclass
class RContext:
  """Context for the td rangeify pass: hands out unique range numbers via new_range."""
  # Next id passed to UOp.range. The annotation is required: without it @dataclass does not
  # treat the attribute as a field, so it could not be set through the constructor and was
  # left out of the generated __repr__ and __eq__.
  range_num: int = 0
def new_range(ctx, s):
  """Create an int RANGE of extent `s` numbered `ctx.range_num`, then advance the counter."""
  range_id = ctx.range_num
  ctx.range_num += 1
  return UOp.range(dtypes.int, s, range_id)
def td_reshape(ctx, idx:UOp, r:UOp):
  """Apply RESHAPE `r` on top of INDEX `idx`.

  `idx.arg` is used as the current shape and `r.arg` as the new one: the indices are
  folded into one linear offset (last axis fastest) and split again over `r.arg`.
  The resulting INDEX carries `r.arg` as its arg. `ctx` is only used by the disabled
  experiment kept in the comments below.
  """
  stride = 1
  terms = []
  for dim, axis_idx in reversed(list(zip(idx.arg, idx.src[1:]))):
    terms.append(axis_idx*stride)
    stride *= dim
  linear = sum(terms)
  pieces = []
  for dim in r.arg[::-1]:
    if resolve(dim!=1):
      # this MOD should limit any ranges outside s
      pieces.append(linear % dim)
      linear //= dim
    else:
      pieces.append(UOp.const(dtypes.int, 0))
  # simplify all pieces together through a temporary SINK, then restore first-axis-first order
  new_idx = UOp.sink(*pieces).simplify().src[::-1] if pieces else ()
  base = idx.src[0]
  # Disabled experiment (was a string literal in the function body): give every non-trivial
  # index a fresh range and tie old and new together with a MERGE node.
  #   out_rng = []
  #   for i,rr in enumerate(ret):
  #     if rr.op not in {Ops.RANGE, Ops.CONST}:
  #       out_rng.append(new_range(ctx, r.arg[i]))
  #     else:
  #       out_rng.append(rr)
  #   mm = [idx.src[0]]
  #   for x,y in zip(ret, out_rng):
  #     if x is not y:
  #       mm.append(x)
  #       mm.append(y)
  #   if len(mm) > 1:
  #     ii = UOp(Ops.MERGE, idx.dtype, tuple(mm))
  return base.index(*new_idx, dtype=idx.dtype, arg=r.arg)
def td_elementwise(ctx, e:UOp):
  """Hoist the INDEX nodes on an elementwise op's sources to a single INDEX on the op.

  Per axis: if every source uses the same non-zero index and it is not a range closed
  by a REDUCE below `e`, that index is reused; if all sources index at constant 0, so
  does the output; otherwise a fresh range is created. When any fresh range was
  created, each source is wrapped in a MERGE listing (old index, new index) pairs.
  All sources must carry the same shape in their INDEX arg.
  """
  # if the range is closed by a reduce to the left, we can't reuse it
  # TODO: handle composite ranges better
  reduces_left = flatten([u.src[1:] for u in e.toposort() if u.op is Ops.REDUCE])
  shps = [u.arg for u in e.src]
  assert all_same(shps)
  per_src_rngs = [u.src[1:] for u in e.src]
  zero = UOp.const(dtypes.int, 0)
  out_rng = []
  need_merge = False
  for axis, candidates in enumerate(zip(*per_src_rngs)):
    # constant-0 indices (identity comparison, as in the rest of this pass) don't constrain the axis
    live = [c for c in candidates if c is not zero]
    if len(live) == 0:
      out_rng.append(UOp.const(dtypes.int, 0))
    elif all_same(live) and live[0] not in reduces_left:
      out_rng.append(live[0])
    else:
      out_rng.append(new_range(ctx, shps[0][axis]))
      need_merge = True
  if not need_merge:
    new_src = [u.src[0] for u in e.src]
    return e.replace(src=tuple(new_src)).index(*out_rng, arg=shps[0])
  new_src = []
  for u in e.src:
    assert u.op is Ops.INDEX
    merged = [u.src[0]]
    rngs_in_src = [v for v in merged[0].toposort() if v.op is Ops.RANGE]
    # walk the axes from last to first
    for axis, axis_idx in reversed(list(enumerate(u.src[1:]))):
      rngs_in_idx = [v for v in axis_idx.toposort() if v.op is Ops.RANGE]
      if all(v not in rngs_in_src for v in rngs_in_idx):
        # for expands
        continue
      if axis_idx is not out_rng[axis]:
        merged.append(axis_idx)
        merged.append(out_rng[axis])
        #out = UOp(Ops.MERGE, out.dtype, src=(out, idx, out_rng[i]))
    if len(merged) > 1:
      new_src.append(UOp(Ops.MERGE, u.dtype, tuple(merged)))
    else:
      new_src.append(merged[0])
    #mm = []
    #for i,idx in enumerate(u.src[1:]):
    #  if idx is not out_rng[i] and idx is not UOp.const(dtypes.int, 0):
    #    mm.append(UOp(Ops.MERGE, src=(idx, out_rng[i])))
    #new_src.append(UOp(Ops.MBLOCK, u.dtype, (u.src[0],)+tuple(mm)))
  return e.replace(src=tuple(new_src)).index(*out_rng, arg=shps[0])
def td_shrink(idx:UOp, r:UOp):
  """Apply SHRINK `r` on top of INDEX `idx`; only shrinks that start at 0 are supported.

  The indices are kept unchanged and each axis of the shape stored in the INDEX arg
  is clamped to the shrink's end.
  """
  kept_idx = []
  new_shape = []
  for axis_idx, (start, end), dim in zip(idx.src[1:], r.arg, idx.arg):
    assert start == 0
    #if u.vmax >= e: u = (u<e).where(u, UOp(Ops.INVALID, u.dtype))
    kept_idx.append(axis_idx)
    new_shape.append(min(dim, end))
  return idx.src[0].index(*kept_idx, dtype=idx.dtype, arg=tuple(new_shape))
def td_reduce(ctx, idx:UOp, r:UOp):
  """Rewrite REDUCE_AXIS(INDEX) into INDEX(REDUCE).

  The indices on the reduced axes (`r.arg[1]`) become the REDUCE's range sources; the
  outer INDEX uses constant 0 on those axes and records size 1 for them in its shape arg.
  `ctx` is unused.
  """
  rngs = idx.src[1:]
  axes = r.arg[1]
  new_shp = tuple([1 if i in axes else s for i, s in enumerate(idx.arg)])
  reduce_rngs = tuple([rng for i, rng in enumerate(rngs) if i in axes])
  reduced = UOp(Ops.REDUCE, r.dtype, (idx.src[0],)+reduce_rngs, r.arg[0])
  outer_idx = [UOp.const(dtypes.int, 0) if i in axes else rng for i, rng in enumerate(rngs)]
  return reduced.index(*outer_idx, arg=new_shp)
# Rules of the "td rangeify" pass in get_kernelize_map (ctx is an RContext). In this scheme every rewritten
# value is an INDEX whose arg holds the value's shape and whose sources after the first are the per-axis indices.
pm_td_rangeify = PatternMatcher([
#(UPat(Ops.INDEX, src=(UPat(Ops.MERGE, src=(UPat(Ops.LOAD, name="b"),), allow_any_len=True),), allow_any_len=True, name="idx"),
# lambda idx,b: b.src[0].src[0].index(*idx.src[1:], dtype=b.src[0].dtype).load().index(*idx.src[1:], arg=idx.arg)),
# an untagged BUFFER becomes LOAD(INDEX(buffer, fresh range)) wrapped in a 1-D INDEX of shape (b.size,);
# the buffer is tagged with 1 so this rule does not match it again
(UPat(Ops.BUFFER, name="b"), lambda ctx, b:
b.replace(tag=1).index(nr:=new_range(ctx, b.size), dtype=b.dtype.ptr(size=b.size)).load().index(nr, arg=(b.size,)) if b.tag is None else None),
#b.replace(tag=1).index(new_range(ctx, b.size), arg=(b.size,)) if b.tag is None else None),
# a CONST on a DEVICE loses its source and gets an INDEX with the empty shape
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="c"), lambda c: c.replace(src=()).index(arg=())),
(UPat(Ops.RESHAPE, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_reshape),
(UPat(Ops.SHRINK, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_shrink),
# PERMUTE reorders both the indices and the shape stored in the INDEX arg
(UPat(Ops.PERMUTE, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
lambda r,idx: idx.src[0].index(*[idx.src[1+p] for p in r.arg], dtype=idx.dtype, arg=tuple(idx.arg[p] for p in r.arg))),
# 0s are already in place for EXPAND
#(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"), lambda r,idx: idx.replace(arg=r.arg)),
# EXPAND: axes whose size changes get a fresh range of the expanded size, the others keep their index
(UPat(Ops.EXPAND, src=(UPat(Ops.INDEX, name="idx"),), name="r"),
lambda ctx,r,idx: idx.src[0].index(*[ii if s1==s2 else new_range(ctx, s1) for s1,s2,ii in zip(r.arg, idx.arg, idx.src[1:])], arg=r.arg)),
(UPat(GroupOp.Elementwise, src=UPat(Ops.INDEX), name="e"), td_elementwise),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.INDEX, name="idx"),), name="r"), td_reduce),
])
def remove_merge(m):
  """Lower a MERGE node into BUFFERIZE + INDEX.

  The MERGE sources after the first are (old, new) index pairs. Pairs whose two
  members are the same object are dropped; the value is bufferized over the old
  indices and re-indexed with the new ones.
  """
  old_rngs, new_rngs = [], []
  for old, new in zip(m.src[1::2], m.src[2::2]):
    if old is new: continue
    old_rngs.append(old)
    new_rngs.append(new)
  # the direct-substitution shortcut for LOAD sources is switched off
  substitute_loads = False
  if substitute_loads and m.src[0].op is Ops.LOAD:
    # hack for LOAD
    reps = dict(zip(old_rngs, new_rngs))
    return m.src[0].substitute(reps)
  buffered = UOp(Ops.BUFFERIZE, m.dtype, (m.src[0],)+tuple(old_rngs), arg=m.device)
  return buffered.index(*new_rngs)
# "remove merge" pass of get_kernelize_map: every MERGE is lowered by remove_merge.
no_merge = PatternMatcher([
(UPat(Ops.MERGE, name="m"), remove_merge),
])
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
tensor_map = graph_rewrite_map(sink, earliest_rewrites, name="earliest")
realize_map = {}
graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add contiguous")
tensor_map = graph_rewrite_map(tensor_map[sink], early_cleanups+remove_tags, input_map=tensor_map, name="cleanup")
rsink = tensor_map[sink]
ctx = RContext()
rsink = graph_rewrite(rsink, pm_td_rangeify, ctx=ctx, name="td rangeify")
rsink = graph_rewrite(rsink, sym, name="symbolic")
# find MOD on RANGE to split
while 1:
#break
reps = {}
for u in rsink.toposort():
if u.op is Ops.MOD and u.src[0].op is Ops.RANGE and u.src[1].op is Ops.CONST:
r = u.src[0].vmax+1
c = u.src[1].arg
if r%c == 0:
reps[u.src[0]] = new_range(ctx, r//c)*c + new_range(ctx, c)
print(len(reps))
if len(reps) == 0: break
rsink = rsink.substitute(reps)
rsink = graph_rewrite(rsink, sym, name="symbolic")
for i in range(0):
print("loop")
real_rngs = rsink.ranges.copy()
for u in rsink.toposort():
if u.op is Ops.REDUCE:
for s in u.src[1:]: real_rngs[s] = None
real_rngs = {x:[] for x in real_rngs}
print("unmovable", [x.arg for x in real_rngs])
for u in rsink.toposort():
if u.op is not Ops.MERGE: continue
assert all(x.op is Ops.RANGE for x in u.src)
r0, r1 = [x for x in u.src]
if r0 is r1: continue
if r0 in real_rngs: real_rngs[r0].append(r1)
if r1 in real_rngs: real_rngs[r1].append(r0)
rew = {}
for k,v in real_rngs.items():
print(k.arg, [x.arg for x in v])
for u in v:
rew[u] = k
rsink = rsink.substitute(rew)
"""
rngs = [x for x in rsink.toposort() if x.op is Ops.RANGE]
mmap = {16:1000, 2:8, 3:9}
rep = {}
for x in rngs:
if x.arg in mmap:
rep[x] = x.replace(arg=mmap[x.arg])
rsink = rsink.substitute(rep)
"""
rsink = graph_rewrite(rsink, no_merge, name="remove merge")
"""
tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="children")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_children_fixup, bottom_up=True, input_map=tensor_map, name="fixup children")
tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="rangeify")
"""
#tensor_map = graph_rewrite_map(tensor_map[sink], symbolic_simple, input_map=tensor_map, name="symbolic")
#tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, bottom_up=True, input_map=tensor_map, name="add buffers")
#if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Rangeify Graph")
if getenv("VIZ"): graph_rewrite(rsink, PatternMatcher([]), name="View Rangeify Graph")
rsink = graph_rewrite(rsink, pm_add_buffers, ctx=[0], bottom_up=True, name="add buffers")
# render
if getenv("SRC") or True:
#rsink = tensor_map[sink]
from tinygrad.codegen.devectorizer import pm_reduce, ReduceContext
rsink = graph_rewrite(rsink, pm_reduce, ctx=ReduceContext(), name="remove reduce")
rsink = graph_rewrite(rsink, pm_debuf, ctx=[0], name="debuf", bottom_up=True)
rsink = graph_rewrite(rsink, sym, name="symbolic 2")
# renumber ranges
#rngs = dedup([x for x in flatten([x.src[2:] for x in list(rsink.toposort())[::-1] if x.op is Ops.STORE]) if x.op is Ops.RANGE])
#rsink = rsink.substitute({x:x.replace(arg=i) for i,x in enumerate(rngs)})
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
rsink = apply_rewrites(rsink, rewrites_for_linearizer)
from tinygrad.renderer.cstyle import CStyleLanguage
src = CStyleLanguage().render(rsink.arg.lst)
print(src)
return {sink:sink}
tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
assign_rep: dict[UOp, UOp] = {}
for u in tensor_map[sink].toposort():
if u.op is not Ops.ASSIGN: continue
kernel_assign[u.buf_uop] = u
for s in u.src[1].src:
# TODO: this is probably broken for MSELECT/MSTACK
if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep:
tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
return tensor_map

View file

@ -14,7 +14,7 @@ from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.schedule.kernelize import get_kernelize_map
from tinygrad.schedule.rangeify import get_kernelize_map
# *** all in scope Tensors are here. this gets relevant UOps ***
@ -252,7 +252,8 @@ class Tensor(MathTrait):
# create the schedule
schedule, var_vals = create_schedule_with_vars(sink)
schedule = memory_planner(schedule)
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
if (DEBUG >= 1 and len(schedule) >= 10) or (DEBUG >= 2 and len(schedule) > 1):
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms")
return schedule, var_vals
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:

View file

@ -12,13 +12,15 @@ class Ops(FastEnum):
NOOP = auto(); SINK = auto(); UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); PRECAST = auto() # noqa: E702
# track children
CHILD = auto()
CHILD = auto(); CHILDREN = auto() # noqa: E702
MERGE = auto(); MBLOCK = auto(); INVALID = auto()
# buffer ops
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
# ops that adjust the behavior of the scheduler
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
BUFFERIZE = auto()
# blocks in linearizer (only used there)
BLOCK = auto(); BLOCKSTART = auto(); BLOCKEND = auto(); BLOCKFINAL = auto() # noqa: E702

View file

@ -136,11 +136,17 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def st(self) -> ShapeTracker|None:
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
if self.op is Ops.INDEX and self.src[0].op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.BUFFER, Ops.BUFFERIZE}: return None
if self.op is Ops.MBLOCK: return None
if self.op in GroupOp.Block: return None
from tinygrad.shape.shapetracker import ShapeTracker
# VIEW and MovementOps define a new ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape((prod([r.vmax+1 for r in self.src[1:]]),))
if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
if self.op in GroupOp.Movement:
if self.src[0].st is None: return None
return unwrap(self.src[0].st).mop(self.op, self.arg)
# CONST with a DEVICE has a shape of ()
if self.op is Ops.CONST and len(self.src) and self.src[0].op is Ops.DEVICE: return ShapeTracker.from_shape(())
# BufferOps and ASSIGN flow ShapeTracker from a direct edge
@ -158,7 +164,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
# otherwise we get the shape from sources
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
if not (src_sts := [x.st for x in self.src if x.st is not None and x.op is not Ops.INDEX]): return None
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
match self.op:
case Ops.MULTI: shape = tuple(self.src[0].shape[a]*len(self.device) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
@ -186,10 +192,22 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def ranges(self) -> dict[UOp, None]:
if self.op is Ops.RANGE: return {self:None}
if self.op in {Ops.CONTIGUOUS, Ops.REDUCE, Ops.STORE}:
if self.op is Ops.MERGE:
ret = self.src[0].ranges.copy()
for s in self.src[1::2]:
if s in ret: del ret[s]
for s in self.src[2::2]:
ret.update(s.ranges)
return ret
if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
ret = self.src[0].ranges.copy()
for s in self.src[1:]:
if s in ret: del ret[s]
elif self.op in {Ops.STORE}:
ret = self.src[0].ranges.copy()
ret.update(self.src[1].ranges)
for s in self.src[2:]:
if s in ret: del ret[s]
else:
ret = {}
for s in self.src: ret.update(s.ranges)
@ -293,6 +311,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def reduce(self, *src:UOp, **kwargs):
  """Build a REDUCE UOp whose sources are this node followed by `src`.

  A `dtype` keyword, if given, is removed from `kwargs` and used as the result dtype;
  otherwise this node's dtype is used. Remaining keywords are passed to the UOp constructor.
  """
  out_dtype = kwargs.pop('dtype', self.dtype)
  return UOp(Ops.REDUCE, out_dtype, src=(self, *src), **kwargs)
def contiguous(self, *args, **kwargs):
  """Build a CONTIGUOUS UOp of this node's dtype; `args` are appended as extra sources after this node."""
  sources = (self, *args)
  return UOp(Ops.CONTIGUOUS, dtype=self.dtype, src=sources, **kwargs)
def contiguous_backward(self):
  """Return the result of applying the CONTIGUOUS_BACKWARD op to this node via `alu`."""
  op = Ops.CONTIGUOUS_BACKWARD
  return self.alu(op)
def bufferize(self, *args, **kwargs):
  """Build a BUFFERIZE UOp of this node's dtype; `args` are appended as extra sources after this node."""
  sources = (self, *args)
  return UOp(Ops.BUFFERIZE, dtype=self.dtype, src=sources, **kwargs)
def fuse(self):
  """Return the result of applying the FUSE op to this node via `alu`."""
  op = Ops.FUSE
  return self.alu(op)
def allreduce(self, op, device:str|tuple[str, ...]|UOp):
assert isinstance(self.device, tuple), f"allreduce must be on tuple {self.device} isn't"
@ -671,6 +690,7 @@ class UPat(MathTrait):
def reduce(self, *src:UPat, **kwargs):
  """Build a REDUCE pattern with this pattern's dtype whose sources are this pattern followed by `src`."""
  sources = (self, *src)
  return UPat(Ops.REDUCE, self.dtype, src=sources, **kwargs)
def fuse(self):
  """Return the result of applying the FUSE op to this pattern via `alu`."""
  op = Ops.FUSE
  return self.alu(op)
def or_broadcasted(self, **kwargs):
  """Combine this pattern, via `UPat.any`, with a VECTORIZE pattern of the same dtype whose `src` is this pattern.

  Keyword arguments are passed to the VECTORIZE pattern's constructor only.
  """
  vectorized = UPat(Ops.VECTORIZE, self.dtype, src=self, **kwargs)
  return UPat.any(self, vectorized)
def contiguous(self, *args, **kwargs):
  """Build a CONTIGUOUS pattern of this pattern's dtype; `args` are appended as extra sources after this pattern."""
  sources = (self, *args)
  return UPat(Ops.CONTIGUOUS, dtype=self.dtype, src=sources, **kwargs)
def const_like(self, b:ConstLike):
  """Build a constant pattern through `UPat.const` from this pattern's dtype and the value `b`."""
  # cast only narrows the static type of b for the type checker
  value = cast(ConstType, b)
  return UPat.const(self.dtype, value)
def alu(self, op:Ops, *src:UPat):

View file

@ -19,7 +19,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
Ops.CHILD: "#80fff0"}
Ops.CHILDREN: "#80ffc0", Ops.CHILD: "#80fff0", Ops.BUFFERIZE: "#FF991C"}
# VIZ API
@ -73,7 +73,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
for idx,x in enumerate(u.src):
if x in excluded:
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}"
if x.op is Ops.CONST and dtypes.is_float(u.dtype): label += f"\nCONST{idx} {x.arg:g}" + (f" {x.src[0].op}" if len(x.src) else "")
else: label += f"\n{x.op.name}{idx} {x.arg}"
try:
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None: