Compare commits

...

3 commits

Author SHA1 Message Date
George Hotz
50fd47d239 never mind, i don't like this 2025-11-17 19:47:08 -08:00
George Hotz
a6e976b2d8 works 2025-11-17 19:22:26 -08:00
George Hotz
e279c80631 outer vmap 2025-11-17 19:01:14 -08:00
8 changed files with 131 additions and 10 deletions

View file

@ -71,13 +71,29 @@ class TestOuterScan(unittest.TestCase):
ref.realize() ref.realize()
return vec, mats, ref return vec, mats, ref
def test_uop_fold_matmul(self):
  # fold variant of the scan test: a single (1, 10) buffer is overwritten on every step
  vec, mats, ref = self._test_scan()

  # 3 matmuls with FOLD
  step = UOp.range(3, -100, AxisType.OUTER)
  acc = Tensor.empty(1, 10)
  # on step 0 the input is vec, on later steps it is the accumulator buffer itself
  carried = Tensor(step.eq(0).where(vec.uop, acc.uop))
  product = carried @ mats[step]
  # store the product back into the accumulator and end the OUTER range on the store
  looped_store = acc.uop.store(product.uop).end(step)
  out = Tensor(acc.uop.after(looped_store))
  out.realize()

  # TODO: testing allclose (the folded result is compared against ref[2])
  assert Tensor.allclose(ref[2], out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
def test_uop_scan_matmul(self): def test_uop_scan_matmul(self):
vec, mats, ref = self._test_scan() vec, mats, ref = self._test_scan()
# 3 matmuls with SCAN # 3 matmuls with SCAN
i = UOp.range(3, -100, AxisType.OUTER) i = UOp.range(3, -100, AxisType.OUTER)
out = Tensor.empty(3, 1, 10) out = Tensor.empty(3, 1, 10)
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i] phi = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop))
comp = phi @ mats[i]
store = out[i].uop.store(comp.uop).end(i) store = out[i].uop.store(comp.uop).end(i)
out = Tensor(out.uop.after(store)) out = Tensor(out.uop.after(store))
out.realize() out.realize()
@ -144,5 +160,84 @@ class TestOuterworld(unittest.TestCase):
out = out.reshape(1, 3).expand(a, 3).contiguous().realize() out = out.reshape(1, 3).expand(a, 3).contiguous().realize()
self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist()) self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], out.tolist())
class TestVmap(unittest.TestCase):
  # Each test indexes a tensor with a range UOp, computes on the slice, then calls
  # Tensor.end(range) on the result. Tests tagged "outer" build the range with AxisType.OUTER.

  def test_vmap_inner(self):
    src = Tensor.ones(3, 2).contiguous()
    # vmap across axis 0
    rng = UOp.range(3, -1)
    res = (src[rng]*2).end(rng)
    res.realize()
    self.assertTrue((res==2).all().item())

  def test_vmap_outer(self):
    src = Tensor.ones(3, 2).contiguous()
    # vmap across axis 0
    rng = UOp.range(3, -1, AxisType.OUTER)
    res = (src[rng]*2).end(rng)
    res.realize()
    self.assertTrue((res==2).all().item())

  def test_fancy_vmap(self):
    lhs = Tensor.arange(9).reshape(3,3).contiguous()
    rhs = Tensor.arange(9).reshape(3,3).contiguous()
    rng = UOp.range(3, -1)
    # the same range selects a column of lhs and a row of rhs; the two slices are added
    res = (lhs[:,rng] + rhs[rng,:]).end(rng).realize()
    self.assertListEqual([[0,4,8],[4,8,12],[8,12,16]], res.tolist())

  def test_vmap_inner_fusion(self):
    src = Tensor.ones(3, 10, 2).contiguous()
    # vmap across axis 0
    rng = UOp.range(3, -1)
    # work is applied both before (sum, *2) and after (*4) ending the range
    res = (src[rng].sum(axis=0)*2).end(rng)*4
    res.realize()
    self.assertTrue((res==10*2*4).all().item())

  def test_vmap_outer_fusion(self):
    src = Tensor.ones(3, 10, 2).contiguous()
    # vmap across axis 0
    rng = UOp.range(3, -1, AxisType.OUTER)
    # work is applied both before (sum, *2) and after (*4) ending the range
    res = (src[rng].sum(axis=0)*2).end(rng)*4
    res.realize()
    self.assertTrue((res==10*2*4).all().item())

  def test_vmap_outer_matmul(self):
    # smoke test: only checks that realize() completes, no value assertion
    vec = Tensor.ones(1, 10).contiguous().requires_grad_()
    stacked = Tensor.ones(3, 10, 10).contiguous()
    # vmap across axis 0
    rng = UOp.range(3, -1, AxisType.OUTER)
    res = (vec @ stacked[rng]).end(rng)
    res.realize()

  def test_vmap_outer_matmul_grad(self):
    # smoke test: only checks that backward() and realizing mats.grad complete
    vec = Tensor.ones(1, 10).contiguous().requires_grad_()
    mats = Tensor.ones(3, 10, 10).contiguous().requires_grad_()
    # vmap across axis 0
    rng = UOp.range(3, -1, AxisType.OUTER)
    res = (vec @ mats[rng]).end(rng)
    res.mean().backward()
    mats.grad.realize()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View file

View file

@ -90,7 +90,7 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
if rk.op is Ops.END: schedule.append(rk) if rk.op is Ops.END: schedule.append(rk)
else: else:
raise RuntimeError(f"can't schedule {k.op}") raise RuntimeError(f"can't schedule {k.op}")
for x in children[k]: for x in children[rk]:
in_degree[x] -= 1 in_degree[x] -= 1
if in_degree[x] == 0: queues[_heuristic(x)].append(x) if in_degree[x] == 0: queues[_heuristic(x)].append(x)

View file

@ -43,6 +43,8 @@ pm_gradient = PatternMatcher([
(UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)), (UPat(Ops.KERNEL, name="k"), lambda ctx, k: k.arg.grad_fxn(ctx, k)),
# there's no gradient for bitcast # there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda: (None,)), (UPat(Ops.BITCAST), lambda: (None,)),
# this only works on single ends of outer ranges
(UPat(Ops.END, name="e"), lambda ctx, e: (ctx.shrink(((e.src[1],e.src[1]+1),)+(None,)*(len(ctx.shape)-1)).reshape(ctx.shape[1:]), None)),
]) ])
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]: def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:

View file

@ -24,8 +24,8 @@ def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
pm_generate_realize_map = PatternMatcher([ pm_generate_realize_map = PatternMatcher([
# always realize SINK src # always realize SINK src
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
# always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE # always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE/END
(UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE}, name="tr"), realize), (UPat({Ops.COPY, Ops.BUFFER_VIEW, Ops.CONTIGUOUS, Ops.STORE, Ops.END}, name="tr"), realize),
# realize srcs of COPY, MSELECT, MSTACK # realize srcs of COPY, MSELECT, MSTACK
(UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs), (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_srcs),
# realize ASSIGN and input to assign (might be optimized out) # realize ASSIGN and input to assign (might be optimized out)
@ -66,6 +66,9 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE]) new_src = s.end(*[r for r in closed_ranges if r.op is Ops.RANGE])
del ctx.realize_map[s] del ctx.realize_map[s]
else: else:
if new_src.op is Ops.END:
# skip END
new_src = new_src.src[0]
# None in the device assigns it a number later # None in the device assigns it a number later
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL) opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None) new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
@ -173,8 +176,12 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map] consumer_rngs = [rctx.range_map[c][0] for c in consumer_map[x] if c in rctx.range_map]
if x in rctx.realize_map: if x in rctx.realize_map:
# if this is in the realize_map, we create new ranges (at the output) if x.op is Ops.END:
out_rngs = tuple(rctx.new_range(s) for s in x.shape) # for END, we use the ranges in the src as the early ones
out_rngs = x.src[1:]+tuple(rctx.new_range(s) for s in x.src[0].shape)
else:
# if this is in the realize_map, we create new ranges (at the output)
out_rngs = tuple(rctx.new_range(s) for s in x.shape)
# all ranges are ended now # all ranges are ended now
ending_ranges[x] = [] ending_ranges[x] = []
# mark all ranges as ended # mark all ranges as ended
@ -249,6 +256,10 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
if x.op is Ops.REDUCE_AXIS: if x.op is Ops.REDUCE_AXIS:
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape))) rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
# END ends ranges
if x.op is Ops.END:
rngs = rngs[len(x.src)-1:]
if debug: if debug:
realized_ranges = rctx.realize_map.get(x, None) realized_ranges = rctx.realize_map.get(x, None)
if x.op is Ops.RESHAPE or len(rngs) != len(out_rngs): if x.op is Ops.RESHAPE or len(rngs) != len(out_rngs):

View file

@ -472,6 +472,7 @@ pm_add_range_tags = PatternMatcher([
]) ])
def split_store(ctx:list[UOp], x:UOp) -> UOp|None: def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
# if we have any non-outer ranges open here, we don't split
if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None if len([r for r in x.ranges if r.arg[-1] != AxisType.OUTER]): return None
# ends of outer range don't go in kernels # ends of outer range don't go in kernels
@ -499,7 +500,12 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}") raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
return kernel return kernel
def split_inner_and_outer_end(x: UOp):
  # Rewrite an END whose ranges mix AxisType.OUTER with other axis types into two nested ENDs:
  # the non-OUTER ranges are ended first, then the OUTER ranges are ended around that.
  # Returns None (no rewrite) when the END holds only one of the two kinds.
  outer, inner = partition(x.src[1:], lambda rng: rng.arg[-1] == AxisType.OUTER)
  if len(outer) == 0 or len(inner) == 0: return None
  ended_inner = x.src[0].end(*inner)
  return ended_inner.end(*outer)
split_kernels = PatternMatcher([ split_kernels = PatternMatcher([
(UPat(Ops.END, name="x"), split_inner_and_outer_end),
(UPat((Ops.STORE, Ops.END), name="x"), split_store), (UPat((Ops.STORE, Ops.END), name="x"), split_store),
]) ])
@ -510,7 +516,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
return x.replace(tag=(len(ctx)-1,)) return x.replace(tag=(len(ctx)-1,))
add_tags = PatternMatcher([ add_tags = PatternMatcher([
# don't tag BUFFERs, they are global # don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END, (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop), Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)), (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
]) ])

View file

@ -481,6 +481,9 @@ class Tensor(OpMixin):
if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
raise RuntimeError(f"unhandled UOp {y}") raise RuntimeError(f"unhandled UOp {y}")
def end(self, *rngs:UOp):
  # Tensor-level wrapper around UOp.end: the ranges are forwarded as extra_args and this
  # tensor's dtype as a keyword argument to _apply_uop.
  # NOTE(review): presumably _apply_uop passes both through to UOp.end; confirm against its definition.
  forwarded = {"dtype": self.dtype}
  return self._apply_uop(UOp.end, extra_args=rngs, **forwarded)
# ***** creation entrypoint ***** # ***** creation entrypoint *****
@staticmethod @staticmethod

View file

@ -219,9 +219,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,) case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return (self.ptrdtype.size,)
# passthrough ops # passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER:
return self.src[0]._shape return self.src[0]._shape
# end adds dims to the front
case Ops.END:
return None if self.src[0]._shape is None else (tuple(x.vmax+1 for x in self.src[1:]) + self.src[0]._shape)
# ops with custom handling # ops with custom handling
case Ops.KERNEL: return self.arg.ast._shape case Ops.KERNEL: return self.arg.ast._shape
@ -398,9 +402,9 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs) def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
def store(self, src:UOp|ConstType, **kwargs): def store(self, src:UOp|ConstType, **kwargs):
return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs) return UOp(Ops.STORE, kwargs.pop("dtype", dtypes.void), (self, UOp.const(self.dtype, src) if not isinstance(src, UOp) else src), **kwargs)
def end(self, *src:UOp): def end(self, *src:UOp, **kwargs):
if len(src) == 0: return self if len(src) == 0: return self
return UOp(Ops.END, src=(self,)+src) return UOp(Ops.END, src=(self,)+src, **kwargs)
def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs) def after(self, *src:UOp, **kwargs): return UOp(Ops.AFTER, self.dtype, (self,)+src, **kwargs)
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x)) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src) def barrier(self, *src:UOp): return UOp(Ops.BARRIER, src=(self,)+src)