mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
24 commits
master
...
kernelless
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6945b90b0 | ||
|
|
dbd95a71dd | ||
|
|
969680a246 | ||
|
|
dce989ed94 | ||
|
|
11abcbd182 |
||
|
|
edf6caae81 |
||
|
|
96869abaf0 | ||
|
|
af6481eaba | ||
|
|
b5ad66cdf2 | ||
|
|
7976a01d59 | ||
|
|
6f516e46c6 | ||
|
|
439ada1ada | ||
|
|
c342809b3e | ||
|
|
b4ab6de416 | ||
|
|
3d31b0b5f6 | ||
|
|
5bab842337 | ||
|
|
427f773bc2 | ||
|
|
45a09207f9 | ||
|
|
494e951e90 | ||
|
|
c4410e91fd | ||
|
|
905019a4ec | ||
|
|
149c3f8fe9 | ||
|
|
5706e2d845 | ||
|
|
c66d2082c6 |
15 changed files with 357 additions and 30 deletions
|
|
@ -743,8 +743,9 @@ class TestFloat4(unittest.TestCase):
|
|||
len([uop for uop in uops if uop.op is Ops.STORE and uop.src[1].dtype == dtypes.half.vec(4)]))
|
||||
|
||||
def test_float4_basic(self):
|
||||
a = Tensor.empty(2, 8).realize()
|
||||
b = Tensor.empty(2, 8).realize()
|
||||
# NOTE: this used to fuse from (2, 8)
|
||||
a = Tensor.empty(16).realize()
|
||||
b = Tensor.empty(16).realize()
|
||||
c = a + b
|
||||
|
||||
s = c.schedule()[0]
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ class TestOps(unittest.TestCase):
|
|||
b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
|
||||
helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)
|
||||
|
||||
@unittest.skip("this is undefined, right?")
|
||||
def test_empty_0(self):
|
||||
helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True)
|
||||
|
||||
|
|
|
|||
41
test/test_rangeify.py
Normal file
41
test/test_rangeify.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.uop.ops import KernelInfo
|
||||
from tinygrad.opt.kernel import Opt, OptOps
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
def with_opts(c:Tensor, opts_to_apply:list[Opt]):
|
||||
s = c.schedule()[-1]
|
||||
program = get_program(s.ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply))), Device.default.renderer)
|
||||
print(program.src)
|
||||
|
||||
class TestRangeify(unittest.TestCase):
|
||||
def test_dont_upcast(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
c = a + b
|
||||
with_opts(c, [])
|
||||
|
||||
def test_upcast(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
c = a + b
|
||||
with_opts(c, [Opt(op=OptOps.UPCAST, axis=1, arg=4)])
|
||||
|
||||
def test_upcast_sum(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = a.sum(axis=1)
|
||||
with_opts(b, [Opt(op=OptOps.UPCAST, axis=0, arg=4)])
|
||||
|
||||
def test_unroll_sum(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = a.sum(axis=1)
|
||||
with_opts(b, [Opt(op=OptOps.UNROLL, axis=0, arg=4)])
|
||||
|
||||
def test_both_sum(self):
|
||||
a = Tensor.empty(4, 4)
|
||||
b = a.sum(axis=1)
|
||||
with_opts(b, [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -30,9 +30,19 @@ class TestTiny(unittest.TestCase):
|
|||
def test_gemm(self, N=64, out_dtype=dtypes.float):
|
||||
a = Tensor.ones(N,N).contiguous()
|
||||
b = Tensor.eye(N).contiguous()
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
self.assertListEqual((out:=a@b).contiguous().flatten().tolist(), [1.0]*(N*N))
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
def test_eye(self):
|
||||
a = Tensor.eye(4, dtype=dtypes.int)
|
||||
self.assertListEqual(a.tolist(), [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
|
||||
|
||||
def test_conv(self, N=32):
|
||||
a = Tensor.ones(1,4,N,N).contiguous()
|
||||
w1 = Tensor.ones(16,4,3,3).contiguous()
|
||||
out = a.conv2d(w1)
|
||||
self.assertTrue(all([x == 36.0 for x in out.contiguous().flatten().tolist()]))
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
def test_random(self):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from tinygrad.uop.spec import type_verify
|
|||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.rangeify import pm_rangeify, pm_name, RangeifyContext
|
||||
from tinygrad.codegen.lowerer import pm_lowerer, get_index
|
||||
from tinygrad.codegen.quantize import pm_quant
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
|
|
@ -43,7 +44,9 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
|||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
#ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
ret.append(RewriteStep(pm_rangeify, lambda _: RangeifyContext(), name="rangeify", bottom_up=True))
|
||||
ret.append(RewriteStep(pm_name, lambda _: [0], name="name"))
|
||||
|
||||
# ** expander (expand_rewrite) **
|
||||
ret.append(RewriteStep(sym+migrate_indexing, name="initial symbolic"))
|
||||
|
|
|
|||
|
|
@ -260,6 +260,8 @@ pm_render = PatternMatcher([
|
|||
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
|
||||
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
|
||||
len(store.src) <= 2 or store.src[2].op != Ops.IF else None),
|
||||
# TODO: CONST shouldn't have src
|
||||
(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
])
|
||||
|
||||
# *** Ops.REDUCE -> Ops.DEFINE_ACC ***
|
||||
|
|
@ -277,7 +279,7 @@ def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
|
|||
return [inp]
|
||||
|
||||
def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
inp, reduce_range = red.src[0], red.src[1:]
|
||||
inp, acc, reduce_range = red.src[0], red.src[1], red.src[2:]
|
||||
lst = horizontal_reduce(inp, red.dtype)
|
||||
assert all(x.dtype == red.dtype for x in lst), f"horizontal reduction mismatch {lst[0].dtype} != {red.dtype}"
|
||||
# if we have a range
|
||||
|
|
@ -285,8 +287,8 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
|||
topo = inp.toposort()
|
||||
stored_ranges = flatten([x.src[2:] for x in topo if x.op is Ops.STORE])
|
||||
input_ranges = tuple([x for x in topo if x.op is Ops.RANGE and x not in reduce_range and x not in stored_ranges])
|
||||
identity = red.const_like(identity_element(red.arg, red.dtype.scalar()))
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
identity = UOp.const(red.dtype, identity_element(red.arg, red.dtype.scalar()))
|
||||
#acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=1, addrspace=AddrSpace.REG), arg=(ctx.acc_num,)).index(UOp.const(dtypes.int, 0))
|
||||
do_store = acc.store(identity, UOp(Ops.NOOP, src=input_ranges)) if len(input_ranges) else acc.store(identity)
|
||||
lst = [acc.load(do_store, *reduce_range)] + lst # put acc as the first element
|
||||
ctx.acc_num += 1
|
||||
|
|
@ -370,7 +372,7 @@ def reduce_collapse(red:UOp):
|
|||
|
||||
def reduce_unparented(red:UOp):
|
||||
if red.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
|
||||
reduce_parented, reduce_unparented = partition(red.src[2:], lambda x: x in red.src[0].sparents)
|
||||
if len(reduce_unparented) == 0: return None
|
||||
ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
|
||||
if red.arg is Ops.ADD:
|
||||
|
|
|
|||
249
tinygrad/codegen/rangeify.py
Normal file
249
tinygrad/codegen/rangeify.py
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, KernelInfo, GroupOp, AxisType, TRACK_MATCH_STATS, identity_element
|
||||
from tinygrad.opt.kernel import axis_colors, Opt, OptOps
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.helpers import argsort, colored, prod, all_same, getenv
|
||||
|
||||
@dataclass
|
||||
class RangeifyContext:
|
||||
idx: int = 0
|
||||
regs: int = 0
|
||||
opts: tuple[Opt, ...] = ()
|
||||
|
||||
def map_store(ctx:RangeifyContext, x:UOp):
|
||||
if x.tag == 1: return None
|
||||
ranges = []
|
||||
for i,s in enumerate(x.shape):
|
||||
upcast_amount = prod([o.arg if o.arg != 0 else s for o in ctx.opts if o.axis == i and o.op == OptOps.UPCAST])
|
||||
if resolve(s!=1):
|
||||
if upcast_amount != 1:
|
||||
assert s%upcast_amount == 0
|
||||
rng = UOp.range(dtypes.int, s//upcast_amount, (ctx.idx, AxisType.LOOP)) * upcast_amount
|
||||
rng = rng + UOp.range(dtypes.int, upcast_amount, (ctx.idx+1, AxisType.UPCAST))
|
||||
ranges.append(rng)
|
||||
ctx.idx += 2
|
||||
else:
|
||||
ranges.append(UOp.range(dtypes.int, s, (ctx.idx, AxisType.LOOP)))
|
||||
ctx.idx += 1
|
||||
else:
|
||||
ranges.append(UOp.const(dtypes.int, 0))
|
||||
mm = UOp(Ops.INDEX, dtype=x.src[0].dtype, src=(x.src[0],)+tuple(ranges))
|
||||
mm2 = UOp(Ops.INDEX, dtype=x.src[0].dtype, src=(x.src[1],)+tuple(ranges))
|
||||
return UOp(Ops.STORE, src=(mm, mm2)+tuple([x for x in UOp.sink(*ranges).toposort() if x.op is Ops.RANGE]), tag=1)
|
||||
|
||||
def map_load(ctx:RangeifyContext, idx:UOp, load:UOp):
|
||||
out_ranges = idx.src[1:]
|
||||
idx_sink = UOp.sink(*out_ranges)
|
||||
upcast_ranges = [x for x in idx_sink.toposort() if x.op is Ops.RANGE and x.arg[1] in (AxisType.UPCAST, AxisType.UNROLL)]
|
||||
upcast_shape = tuple([x.vmax+1 for x in upcast_ranges])
|
||||
if len(upcast_ranges):
|
||||
buf = UOp(Ops.DEFINE_REG, load.dtype.ptr(size=prod([x.vmax+1 for x in upcast_ranges]), addrspace=AddrSpace.REG), arg=(ctx.regs,))
|
||||
buf = buf.reshape(upcast_shape)
|
||||
ctx.regs += 1
|
||||
replace_ranges = {}
|
||||
for r in upcast_ranges:
|
||||
replace_ranges[r] = UOp.range(dtypes.int, r.vmax+1, (ctx.idx, AxisType.UPCAST))
|
||||
ctx.idx += 1
|
||||
replace_ranges_v = list(replace_ranges.values())
|
||||
out_ranges = idx_sink.substitute(replace_ranges).src
|
||||
ret = load.src[0].index(*out_ranges).load()
|
||||
ret = buf.index(*upcast_ranges).load(buf.index(*replace_ranges_v).store(ret, *replace_ranges_v, tag=1))
|
||||
return ret
|
||||
else:
|
||||
return UOp(Ops.INDEX, load.src[0].dtype, src=(load.src[0],)+out_ranges).load()
|
||||
|
||||
def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
|
||||
rngs = list(idx.src[1:])
|
||||
|
||||
input_ranges = [x for x in UOp.sink(*rngs).toposort() if x.op is Ops.RANGE and x.arg[1] != AxisType.UPCAST]
|
||||
|
||||
upcast_ranges = [x for x in UOp.sink(*rngs).toposort() if x.op is Ops.RANGE and x.arg[1] == AxisType.UPCAST]
|
||||
upcast_shape = tuple([x.vmax+1 for x in upcast_ranges])
|
||||
acc = UOp(Ops.DEFINE_REG, red.dtype.ptr(size=prod([x.vmax+1 for x in upcast_ranges]), addrspace=AddrSpace.REG),
|
||||
arg=(ctx.regs,)).reshape(upcast_shape)
|
||||
ctx.regs += 1
|
||||
|
||||
|
||||
# create reduce dims (before new upcast dims)
|
||||
new_ranges = []
|
||||
reduce_axis = 0
|
||||
for i,s in enumerate(red.src[0].shape):
|
||||
if i in red.arg[1]:
|
||||
unroll_amount = prod([o.arg if o.arg != 0 else s for o in ctx.opts if o.axis == reduce_axis and o.op == OptOps.UNROLL])
|
||||
reduce_axis += 1
|
||||
assert rngs[i].op == Ops.CONST
|
||||
#rngs[i] = UOp.range(dtypes.int, s, (ctx.idx, AxisType.REDUCE))
|
||||
#ctx.idx += 1
|
||||
if unroll_amount != 1:
|
||||
assert s%unroll_amount == 0
|
||||
rngs[i] = UOp.range(dtypes.int, s//unroll_amount, (ctx.idx, AxisType.REDUCE)) * unroll_amount
|
||||
rngs[i] = rngs[i] + UOp.range(dtypes.int, unroll_amount, (ctx.idx+1, AxisType.UNROLL))
|
||||
ctx.idx += 2
|
||||
new_ranges.extend(list(rngs[i].src))
|
||||
else:
|
||||
rngs[i] = UOp.range(dtypes.int, s, (ctx.idx, AxisType.REDUCE))
|
||||
ctx.idx += 1
|
||||
new_ranges.append(rngs[i])
|
||||
|
||||
# create new upcast dims
|
||||
replace_ranges = {}
|
||||
for r in upcast_ranges:
|
||||
replace_ranges[r] = UOp.range(dtypes.int, r.vmax+1, (ctx.idx, AxisType.UPCAST))
|
||||
ctx.idx += 1
|
||||
replace_ranges_v = list(replace_ranges.values())
|
||||
rngs = list(UOp.sink(*rngs).substitute(replace_ranges).src)
|
||||
|
||||
# identity store
|
||||
identity_ranges = []
|
||||
for r in upcast_ranges:
|
||||
identity_ranges.append(UOp.range(dtypes.int, r.vmax+1, (ctx.idx, AxisType.LOOP)))
|
||||
ctx.idx += 1
|
||||
identity = UOp.const(red.dtype, identity_element(red.arg[0], red.dtype.scalar()))
|
||||
do_identity_store = acc.index(*identity_ranges).store(identity, *identity_ranges, UOp(Ops.NOOP, src=tuple(input_ranges)), tag=1)
|
||||
|
||||
mm = UOp(Ops.INDEX, red.src[0].dtype, src=(red.src[0],)+tuple(rngs))
|
||||
rbufidx = acc.index(*replace_ranges_v)
|
||||
loaded = rbufidx.load(do_identity_store, *new_ranges)
|
||||
reduce_store = rbufidx.store(loaded.alu(red.arg[0], mm), *new_ranges, *replace_ranges_v, tag=1)
|
||||
return acc.index(*replace_ranges.keys()).load(reduce_store)
|
||||
|
||||
#return UOp(Ops.REDUCE, red.dtype, src=(mm, loaded)+tuple(replace_ranges_v)+tuple(new_ranges), arg=red.arg[0])
|
||||
|
||||
def map_reshape(x:UOp, r:UOp):
|
||||
acc = 1
|
||||
to_sum = []
|
||||
for s,src in list(zip(x.shape, x.src[1:]))[::-1]:
|
||||
to_sum.append(acc*src)
|
||||
acc *= s
|
||||
mish = sum(to_sum)
|
||||
ret = []
|
||||
for s in x.src[0].src[0].shape[::-1]:
|
||||
if resolve(s!=1):
|
||||
# this MOD should limit any ranges outside s
|
||||
ret.append(mish % s)
|
||||
mish //= s
|
||||
else:
|
||||
ret.append(UOp.const(dtypes.int, 0))
|
||||
ret = UOp.sink(*ret).simplify().src[::-1] if len(ret) else ()
|
||||
return UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret))
|
||||
|
||||
def map_pad(x:UOp, r:UOp):
|
||||
ret = list(x.src[1:])
|
||||
bigwhere = UOp.const(dtypes.bool, True)
|
||||
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
|
||||
if s == 0 and e == 0: continue
|
||||
where = UOp.const(dtypes.bool, True)
|
||||
if e > 0: where = where & (ret[i] < (sh-e))
|
||||
if s > 0: where = where & (ret[i] >= s)
|
||||
bigwhere = bigwhere & where
|
||||
# this is safe but dumb
|
||||
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
|
||||
# mask the load
|
||||
#ret[i] = where.where(ret[i], UOp(Ops.INVALID, dtype=ret[i].dtype))
|
||||
# PAD is with 0
|
||||
return bigwhere.simplify().where(UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple(ret)), UOp.const(r.dtype, 0))
|
||||
|
||||
def capture_sink(ctx:RangeifyContext, x: UOp):
|
||||
if x.tag == 1:
|
||||
late_subs = {}
|
||||
for k,v in x.get_children_map().items():
|
||||
if k.op is Ops.CHILDREN and all([vi.op is Ops.INDEX for vi in v]):
|
||||
idxs = list(zip(*[vi.src[1:] for vi in v]))
|
||||
new_idxs = []
|
||||
only_new_idxs = []
|
||||
save_shape = []
|
||||
full_shape = []
|
||||
for idx in idxs:
|
||||
if all_same(idx):
|
||||
new_idxs.append(idx[0])
|
||||
save_shape.append(1)
|
||||
full_shape.append(idx[0].vmax+1)
|
||||
else:
|
||||
ll = [z.vmax+1 for z in idx]
|
||||
assert all_same(ll), f"mismatch shapes {ll}"
|
||||
save_shape.append(ll[0])
|
||||
full_shape.append(ll[0])
|
||||
new_idxs.append(UOp.range(dtypes.int, ll[0], (ctx.idx, AxisType.LOOP)))
|
||||
only_new_idxs.append(new_idxs[-1])
|
||||
ctx.idx += 1
|
||||
new_idxs = tuple(new_idxs)
|
||||
inp = k.src[0]
|
||||
print(save_shape, full_shape)
|
||||
if len(save_shape):
|
||||
buf = UOp(Ops.DEFINE_REG, inp.dtype.ptr(size=prod(save_shape), addrspace=AddrSpace.REG), arg=(ctx.regs,))
|
||||
ctx.regs += 1
|
||||
buf = buf.reshape(tuple(save_shape)).expand(tuple(full_shape))
|
||||
store = UOp(Ops.INDEX, buf.dtype, (buf,)+new_idxs).store(UOp(Ops.INDEX, inp.dtype, (inp,)+new_idxs), *only_new_idxs, tag=1)
|
||||
for vi in v:
|
||||
late_subs[vi] = UOp(Ops.INDEX, buf.dtype, (buf,)+vi.src[1:]).load(store)
|
||||
else:
|
||||
print("no replace")
|
||||
for vi in v:
|
||||
assert new_idxs == vi.src[1:]
|
||||
late_subs[vi] = UOp(Ops.INDEX, inp.dtype, (inp,)+new_idxs)
|
||||
if not len(late_subs): return None
|
||||
return x.substitute(late_subs)
|
||||
if x.arg is not None and x.arg.opts_to_apply is not None: ctx.opts = x.arg.opts_to_apply
|
||||
replace_children = {}
|
||||
for k,v in x.get_children_map().items():
|
||||
if k.op not in {Ops.CHILDREN, Ops.DEVICE} and len(v) > 1:
|
||||
replace_children[k] = UOp(Ops.CHILDREN, dtype=k.dtype, src=(k.replace(tag=len(v)),))
|
||||
if getenv("FUSE") and TRACK_MATCH_STATS > 0: x = x.substitute(replace_children)
|
||||
return x.replace(arg=None, tag=1)
|
||||
|
||||
pm_rangeify = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), capture_sink),
|
||||
|
||||
# TODO: handle INDEX on STORE
|
||||
(UPat(Ops.STORE, name="x"), map_store),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
|
||||
|
||||
# this is like the definitions of these
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([x.src[1+p] for p in argsort(x.src[0].arg)]))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(x.src[1:], r.arg)]))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+tuple([((s-1)-a) if f else a for a,s,f in zip(x.src[1:], r.shape, r.arg)]))),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="x"),
|
||||
lambda r,x: UOp(Ops.INDEX, r.dtype, src=(r.src[0],)+
|
||||
tuple([a.const_like(0) if resolve(x!=y, False) else a for a,x,y in zip(x.src[1:], r.src[0].shape, r.shape)]))),
|
||||
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="x"), map_reshape),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="x"), map_pad),
|
||||
|
||||
# bring where to the front
|
||||
#(UPat(GroupOp.Binary, name="base", src=(UPat.var("c").where(UPat.var("x"), UPat(Ops.INVALID, name="inv")), UPat.var("a"))),
|
||||
# lambda c,x,a,base,inv: c.where(UOp(base.op, base.dtype, (x,a)), inv)),
|
||||
#(UPat(GroupOp.Binary, name="base", src=(UPat.var("c").where(UPat(Ops.INVALID, name="inv"), UPat.var("x")), UPat.var("a"))),
|
||||
# lambda c,x,a,base,inv: c.where(inv, UOp(base.op, base.dtype, (x,a)))),
|
||||
#(UPat(GroupOp.Binary, name="base", src=(UPat.var("a"), UPat.var("c").where(UPat.var("x"), UPat(Ops.INVALID, name="inv")))),
|
||||
# lambda c,x,a,base,inv: c.where(UOp(base.op, base.dtype, (a,x)), inv)),
|
||||
#(UPat(GroupOp.Binary, name="base", src=(UPat.var("a"), UPat.var("c").where(UPat(Ops.INVALID, name="inv"), UPat.var("x")))),
|
||||
# lambda c,x,a,base,inv: c.where(inv, UOp(base.op, base.dtype, (a,x)))),
|
||||
|
||||
# move MAP through elementwise ALU
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE})),), allow_any_len=True, name="x"),
|
||||
lambda x: x.src[0].replace(src=tuple([UOp(Ops.INDEX, dtype=s.dtype, src=(s,)+x.src[1:]) for s in x.src[0].src]))),
|
||||
|
||||
# map load
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.LOAD, name="load"),), allow_any_len=True, name="idx"), map_load),
|
||||
|
||||
# INDEX without ranges on a DEFINE is just index 0
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines),), name="x"), lambda x: x.replace(src=x.src+(UOp.const(dtypes.int, 0),))),
|
||||
|
||||
# CONST can't have axes
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.CONST,name="c"),)), lambda c: c),
|
||||
|
||||
# unbind...but this is too late
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR, name="v"), UPat(Ops.CONST))), lambda v: v),
|
||||
])
|
||||
|
||||
def name_the_sink(x:UOp):
|
||||
if x.arg is not None: return None
|
||||
ranges = sorted([u for u in x.toposort() if u.op is Ops.RANGE], key=lambda y: y.arg)
|
||||
return x.replace(arg=KernelInfo(name='k_'+'_'.join([colored(str(u.src[0].arg), axis_colors[u.arg[1]]) for u in ranges])))
|
||||
|
||||
pm_name = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), name_the_sink),
|
||||
])
|
||||
|
|
@ -27,8 +27,9 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
|
|||
"""
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
|
||||
if __debug__: type_verify(list(modified_ast.toposort()))
|
||||
#modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
|
||||
#if __debug__: type_verify(list(modified_ast.toposort()))
|
||||
modified_ast = ast
|
||||
|
||||
# linearize
|
||||
try:
|
||||
|
|
@ -36,7 +37,7 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
|
|||
except RuntimeError:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(f"ast = {ast}")
|
||||
print(f"opts = {modified_ast.arg.applied_opts}")
|
||||
#print(f"opts = {modified_ast.arg.applied_opts}")
|
||||
raise
|
||||
assert uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class CStyleLanguage(Renderer):
|
|||
# naming
|
||||
prefix = None
|
||||
if u.op is Ops.SPECIAL: r[u] = u.arg[0]
|
||||
elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg}"
|
||||
elif u.op is Ops.RANGE: r[u] = f"ridx{u.arg[0]}"
|
||||
else:
|
||||
prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
|
||||
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.PRECAST: "precast",
|
||||
|
|
|
|||
|
|
@ -250,11 +250,11 @@ view_right = merge_views+PatternMatcher([
|
|||
|
||||
add_buffer_ops = PatternMatcher([
|
||||
# LOAD
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st),)),
|
||||
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).reshape(x.shape),)),
|
||||
# STORE (except for meta ops)
|
||||
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
|
||||
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda ctx,sink:
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),
|
||||
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i).reshape(s.shape), s) for i,x in enumerate(sink.src)])),
|
||||
# passthrough ASSIGN
|
||||
(UPat(Ops.ASSIGN, name="x"), lambda x: x.src[1]),
|
||||
# VALID
|
||||
|
|
@ -294,7 +294,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
# replace global memory ops with the BUFFER they write to
|
||||
ast = graph_rewrite(k.arg.ast, replace_globals, bottom_up=True, name="replace globals")
|
||||
# push views to edges
|
||||
ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
||||
#ast = graph_rewrite(graph_rewrite(ast, view_left, name="Main View Left"), view_right, name="Main View Right")
|
||||
# replace buffer with define_global + add load/store last
|
||||
bufs = []
|
||||
for s in k.src:
|
||||
|
|
@ -302,7 +302,7 @@ def fix_kernel_ast(k:UOp) -> UOp|None:
|
|||
# traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
|
||||
while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
|
||||
bufs.append(s)
|
||||
ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||
ast = graph_rewrite(ast, add_buffer_ops+fix_kernel_ops, bufs, bottom_up=True, name="replace buffer")
|
||||
if ast.op is Ops.SINK and not all_same([x.device for x in k.src]):
|
||||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
|
||||
return k.replace(arg=Kernel(ast, k.arg.metadata))
|
||||
|
|
@ -417,6 +417,12 @@ finalize_contiguous = PatternMatcher([
|
|||
|
||||
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
|
||||
|
||||
new_fixups = PatternMatcher([
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
|
||||
# TODO: this should be BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
|
||||
])
|
||||
|
||||
@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}")
|
||||
def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
|
|
@ -430,7 +436,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
|||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
tensor_map = graph_rewrite_map(sink, new_fixups+multi_pm+do_fuse+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class Ops(FastEnum):
|
|||
|
||||
# buffer ops
|
||||
COPY = auto(); BUFFER = auto(); BUFFER_VIEW = auto(); MSELECT = auto(); MSTACK = auto() # noqa: E702
|
||||
CHILDREN = auto()
|
||||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); FUSE = auto() # noqa: E702
|
||||
|
|
@ -23,6 +24,7 @@ class Ops(FastEnum):
|
|||
# movement ops! these only exist in the tensor graph
|
||||
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
|
||||
MULTI = auto() # MULTI is really a movement op
|
||||
INVALID = auto()
|
||||
|
||||
# view is what all movement ops become
|
||||
VIEW = auto()
|
||||
|
|
@ -82,6 +84,7 @@ class GroupOp:
|
|||
Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB, Ops.FDIV, Ops.POW}
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
Elementwise = set.union(ALU, {Ops.CAST, Ops.BITCAST})
|
||||
|
||||
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
|||
from enum import Enum, auto
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType
|
||||
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
|
||||
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -136,8 +136,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
@functools.cached_property
|
||||
def st(self) -> ShapeTracker|None:
|
||||
if self.op in GroupOp.Block or self.op is Ops.INDEX: return None
|
||||
if self.op in GroupOp.Block: return None
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
if self.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}:
|
||||
return ShapeTracker.from_shape((cast(PtrDType, self.dtype).size,))
|
||||
# VIEW and MovementOps define a new ShapeTracker from the arg
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
|
|
@ -154,7 +156,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
|
||||
# otherwise we get the shape from sources
|
||||
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None
|
||||
assert all_same([x.shape for x in src_sts]), f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}"
|
||||
if not all_same([x.shape for x in src_sts]): raise RuntimeError(f"UOp sources must have the same shape {self} {[x.shape for x in src_sts]}")
|
||||
match self.op:
|
||||
case Ops.MULTI: shape = tuple(self.src[0].shape[a]*len(self.device) if a == self.axis else s for a,s in enumerate(self.src[0].shape))
|
||||
case Ops.BITCAST:
|
||||
|
|
@ -212,11 +214,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
return ret
|
||||
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
|
||||
def __getitem__(self, idx): return self.index(idx)
|
||||
def const_like(self, b:ConstLike):
|
||||
# constants can optionally have a DEVICE source
|
||||
return UOp.const(self.dtype, b, device=self._device, shape=self.shape if self.st is not None else None)
|
||||
try:
|
||||
st = self.st
|
||||
except RuntimeError:
|
||||
st = None
|
||||
return UOp.const(self.dtype, b, device=self._device, shape=st.shape if st is not None else None)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
if count == 1: return self
|
||||
|
|
@ -248,11 +254,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
|||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
|
||||
if shape is not None:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
|
||||
#if shape is not None:
|
||||
#from tinygrad.shape.shapetracker import ShapeTracker
|
||||
#ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
|
||||
if device is not None:
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
#ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
# only has shape if it has device?
|
||||
if shape is not None:
|
||||
ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ try:
|
|||
(UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))),
|
||||
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))),
|
||||
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg[0]}", 0, x.src[0].arg-1, ctx[0]))),
|
||||
(UPat(Ops.LOAD, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
|
||||
(UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
|
||||
(UPat(Ops.CAST, name="x"), lambda x: x.src[0]),
|
||||
|
|
@ -134,7 +134,7 @@ spec = PatternMatcher([
|
|||
(UPat(Ops.DEFINE_REG, src=()), lambda: True),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, int)),
|
||||
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg[0], int)),
|
||||
(UPat(Ops.SPECIAL, src=()), lambda: True),
|
||||
|
||||
(UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)),
|
||||
|
|
|
|||
|
|
@ -474,5 +474,5 @@ sym = symbolic_flat+PatternMatcher([
|
|||
# move const multiply after REDUCE (NOTE: the mul chain can do this, but only if it's a same dtype reduce)
|
||||
((UPat.var("x")*UPat.cvar("c", vec=False)).reduce(arg=Ops.ADD, name="r", allow_any_len=True), lambda x,c,r: r.replace(src=(x,)+r.src[1:])*c.arg),
|
||||
# reduce mul chain, move muls after the reduce
|
||||
(UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
|
||||
#(UPat(Ops.MUL).reduce(name="r", allow_any_len=True), reduce_mul_chain),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
|||
excluded: set[UOp] = set()
|
||||
for u in (toposort:=x.toposort()):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.INVALID} and u is not x: excluded.add(u)
|
||||
# only exclude CONST VIEW source if it has no other children in the graph
|
||||
if u.op is Ops.CONST and len(u.src) != 0 and all(cr.op is Ops.CONST for c in u.src[0].children if (cr:=c()) is not None and cr in toposort):
|
||||
excluded.update(u.src)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue