mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
Compare commits
3 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50fd47d239 | ||
|
|
a6e976b2d8 | ||
|
|
e279c80631 |
8 changed files with 131 additions and 10 deletions
|
|
@ -71,13 +71,29 @@ class TestOuterScan(unittest.TestCase):
|
|||
ref.realize()
|
||||
return vec, mats, ref
|
||||
|
||||
def test_uop_fold_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with FOLD
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(1, 10)
|
||||
phi = Tensor(i.eq(0).where(vec.uop, out.uop))
|
||||
comp = phi @ mats[i]
|
||||
store = out.uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref[2], out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
def test_uop_scan_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with SCAN
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(3, 1, 10)
|
||||
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
|
||||
phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
|
||||
comp = phi @ mats[i]
|
||||
store = out[i].uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
|
|
@ -144,5 +160,84 @@ class TestOuterworld(unittest.TestCase):
|
|||
out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
|
||||
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
|
||||
|
||||
class TestVmap(unittest.TestCase):
|
||||
def test_vmap_inner(self):
|
||||
x = Tensor.ones(3, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
out = x[a]*2
|
||||
out = out.end(a)
|
||||
out.realize()
|
||||
|
||||
self.assertTrue((out==2).all().item())
|
||||
|
||||
def test_vmap_outer(self):
|
||||
x = Tensor.ones(3, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1, AxisType.OUTER)
|
||||
out = x[a]*2
|
||||
out = out.end(a)
|
||||
out.realize()
|
||||
|
||||
self.assertTrue((out==2).all().item())
|
||||
|
||||
def test_fancy_vmap(self):
|
||||
def f(x,y): return x+y
|
||||
|
||||
x = Tensor.arange(9).reshape(3,3).contiguous()
|
||||
y = Tensor.arange(9).reshape(3,3).contiguous()
|
||||
|
||||
a = UOp.range(3, -1)
|
||||
out = f(x[:,a], y[a,:])
|
||||
out = out.end(a).realize()
|
||||
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
|
||||
|
||||
def test_vmap_inner_fusion(self):
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1)
|
||||
out = x[a].sum(axis=0)*2
|
||||
out = out.end(a)*4
|
||||
out.realize()
|
||||
|
||||
self.assertTrue((out==10*2*4).all().item())
|
||||
|
||||
def test_vmap_outer_fusion(self):
|
||||
x = Tensor.ones(3, 10, 2).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1, AxisType.OUTER)
|
||||
out = x[a].sum(axis=0)*2
|
||||
out = out.end(a)*4
|
||||
out.realize()
|
||||
|
||||
self.assertTrue((out==10*2*4).all().item())
|
||||
|
||||
def test_vmap_outer_matmul(self):
|
||||
x = Tensor.ones(1, 10).contiguous().requires_grad_()
|
||||
mats = Tensor.ones(3, 10, 10).contiguous()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1, AxisType.OUTER)
|
||||
out = x @ mats[a]
|
||||
out = out.end(a)
|
||||
|
||||
out.realize()
|
||||
|
||||
def test_vmap_outer_matmul_grad(self):
|
||||
x = Tensor.ones(1, 10).contiguous().requires_grad_()
|
||||
mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
|
||||
|
||||
# vmap across axis 0
|
||||
a = UOp.range(3, -1, AxisType.OUTER)
|
||||
out = x @ mats[a]
|
||||
out = out.end(a)
|
||||
out.mean().backward()
|
||||
|
||||
mats.grad.realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -90,7 +90,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
|||
if rk.op is Ops.END: schedule.append(rk)
|
||||
else:
|
||||
raise RuntimeError(f"can't schedule {k.op}")
|
||||
for x in children[k]:
|
||||
for x in children[rk]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queues[_heuristic(x)].append(x)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ pm_gradient = PatternMatcher([
|
|||
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
|
||||
# there's no gradient for bitcast
|
||||
(UPat(Ops.BITCAST), lambda: (None,)),
|
||||
# this only works on single ends of outer ranges
|
||||
(UPat(Ops.END, name="e"), lambda ctx, e: (ctx.shrink(((e.src[1],e.src[1]+1),)+(None,)*(len(ctx.shape)-1)).reshape(ctx.shape[1:]), None)),
|
||||
])
|
||||
|
||||
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
|
|||
pm_generate_realize_map = PatternMatcher([
|
||||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE
|
||||
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize),
|
||||
# always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE/END
|
||||
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.END}, name="tr"), realize),
|
||||
# realize srcs of COPY, MSELECT, MSTACK
|
||||
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
|
||||
# realize ASSIGN and input to assign (might be optimized out)
|
||||
|
|
@ -66,6 +66,9 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
|||
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
|
||||
del ctx.realize_map[s]
|
||||
else:
|
||||
if new_src.op is Ops.END:
|
||||
# skip END
|
||||
new_src = new_src.src[0]
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||
|
|
@ -173,8 +176,12 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
|
||||
consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
|
||||
if x in rctx.realize_map:
|
||||
# if this is in the realize_map, we create new ranges (at the output)
|
||||
out_rngs = tuple(rctx.new_range(s) for s in x.shape)
|
||||
if x.op is Ops.END:
|
||||
# for END, we use the ranges in the src as the early ones
|
||||
out_rngs = x.src[1:]+tuple(rctx.new_range(s) for s in x.src[0].shape)
|
||||
else:
|
||||
# if this is in the realize_map, we create new ranges (at the output)
|
||||
out_rngs = tuple(rctx.new_range(s) for s in x.shape)
|
||||
# all ranges are ended now
|
||||
ending_ranges[x] = []
|
||||
# mark all ranges as ended
|
||||
|
|
@ -249,6 +256,10 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
|||
if x.op is Ops.REDUCE_AXIS:
|
||||
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
|
||||
|
||||
# END ends ranges
|
||||
if x.op is Ops.END:
|
||||
rngs = rngs[len(x.src)-1:]
|
||||
|
||||
if debug:
|
||||
realized_ranges = rctx.realize_map.get(x, None)
|
||||
if x.op is Ops.RESHAPE or len(rngs) != len(out_rngs):
|
||||
|
|
|
|||
|
|
@ -472,6 +472,7 @@ pm_add_range_tags = PatternMatcher([
|
|||
])
|
||||
|
||||
def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
||||
# if we have any outer ranges open here, we don't split
|
||||
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
|
||||
|
||||
# ends of outer range don't go in kernels
|
||||
|
|
@ -499,7 +500,12 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
|
|||
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
|
||||
return kernel
|
||||
|
||||
def split_inner_and_outer_end(x: UOp):
|
||||
outer_ranges, inner_ranges = partition(x.src[1:], lambda r: r.arg[-1] == AxisType.OUTER)
|
||||
if len(outer_ranges) and len(inner_ranges): return x.src[0].end(*inner_ranges).end(*outer_ranges)
|
||||
|
||||
split_kernels = PatternMatcher([
|
||||
(UPat(Ops.END, name="x"), split_inner_and_outer_end),
|
||||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
])
|
||||
|
||||
|
|
@ -510,7 +516,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
|
|||
return x.replace(tag=(len(ctx)-1,))
|
||||
add_tags = PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
|
||||
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
|
||||
])
|
||||
|
|
|
|||
|
|
@ -481,6 +481,9 @@ class Tensor(OpMixin):
|
|||
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
|
||||
raise RuntimeError(f"unhandled UOp {y}")
|
||||
|
||||
def end(self, *rngs:UOp):
|
||||
return self._apply_uop(UOp.end, extra_args=rngs, dtype=self.dtype)
|
||||
|
||||
# ***** creation entrypoint *****
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -219,9 +219,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
|
||||
|
||||
# passthrough ops
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
|
||||
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER:
|
||||
return self.src[0]._shape
|
||||
|
||||
# end adds dims to the front
|
||||
case Ops.END:
|
||||
return None if self.src[0]._shape is None else (tuple(x.vmax+1 for x in self.src[1:]) + self.src[0]._shape)
|
||||
|
||||
# ops with custom handling
|
||||
case Ops.KERNEL: return self.arg.ast._shape
|
||||
|
||||
|
|
@ -398,9 +402,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
|||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||
def store(self, src:UOp|ConstType, **kwargs):
|
||||
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
|
||||
def end(self, *src:UOp):
|
||||
def end(self, *src:UOp, **kwargs):
|
||||
if len(src) == 0: return self
|
||||
return UOp(Ops.END, src=(self,)+src)
|
||||
return UOp(Ops.END, src=(self,)+src, **kwargs)
|
||||
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue