Compare commits

...

27 commits

Author SHA1 Message Date
George Hotz
d1223922b1 fixed and test is real 2025-12-04 16:52:11 -08:00
George Hotz
05c4b18f91
Merge branch 'master' into sched_cache 2025-12-04 16:46:23 -08:00
George Hotz
f58b3afeb2
Merge branch 'master' into sched_cache 2025-12-03 16:12:44 -08:00
George Hotz
e0a805765e full jit 2025-12-03 16:08:34 -08:00
George Hotz
7c66e44454 fix JIT in examples/gradaccum_mnist.py 2025-12-03 16:00:28 -08:00
George Hotz
e75e391ad4
Merge branch 'master' into sched_cache 2025-12-03 15:41:31 -08:00
George Hotz
8c69e26d22 metadata is best effort 2025-12-03 15:22:58 -08:00
George Hotz
74fb405cc9 reenable the actual schedule cache 2025-12-03 15:03:42 -08:00
George Hotz
bf5de6ba5f delete abstractions2 2025-12-03 15:02:20 -08:00
George Hotz
183b3ced03 fix process replay 2025-12-03 14:56:28 -08:00
George Hotz
2280dae504 src[0].op 2025-12-03 14:50:46 -08:00
George Hotz
9ba612f0b4
Merge branch 'master' into sched_cache 2025-12-03 14:50:29 -08:00
George Hotz
32794853db why is that broken? 2025-12-03 14:44:41 -08:00
George Hotz
4a72a49082
Merge branch 'master' into sched_cache 2025-12-03 14:34:49 -08:00
George Hotz
9e6f8c823d always miss 2025-12-03 14:22:26 -08:00
George Hotz
4459a88a54 fix spec 2025-12-03 14:19:07 -08:00
George Hotz
9cdda8913f put that there 2025-12-03 14:15:13 -08:00
George Hotz
e644d59f9f oops, fix cache 2025-12-03 14:07:04 -08:00
George Hotz
37a930591f preserve metadata 2025-12-03 14:04:20 -08:00
George Hotz
723179dfd6
Merge branch 'master' into sched_cache 2025-12-03 13:43:58 -08:00
George Hotz
81bafb1af3
Merge branch 'master' into sched_cache 2025-12-02 19:59:48 -08:00
George Hotz
ed89217ef2 fix tests 2025-12-02 17:14:06 -08:00
George Hotz
79f2cfcb96 schedule cache cleanup 2025-12-02 16:59:32 -08:00
George Hotz
add768aab0 schedule cache works 2025-12-02 16:40:30 -08:00
George Hotz
2d6cf839d5 local unique 2025-12-02 15:45:56 -08:00
George Hotz
b4c3a6977e
Merge branch 'master' into sched_cache 2025-12-02 12:54:14 -08:00
George Hotz
7f7aa0a7f8 start work on schedule cache 2025-12-02 07:44:10 -08:00
5 changed files with 104 additions and 20 deletions

View file

@ -1,5 +1,6 @@
import gc
from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.schedule import schedule_cache
from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general
@ -68,6 +69,7 @@ if __name__ == "__main__":
t()
# these caches will keep uops alive
schedule_cache.clear()
method_cache.clear()
apply_movement_op.cache_clear()
_apply_reshape.cache_clear()

View file

@ -829,6 +829,7 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()

View file

@ -0,0 +1,24 @@
import unittest
from tinygrad import Tensor
from tinygrad.engine.schedule import schedule_cache
class TestScheduleCache(unittest.TestCase):
def test_simple(self):
a = Tensor.ones(10).contiguous()
b = Tensor.ones(10).contiguous()
Tensor.realize(a, b)
# warm up
for _ in range(2):
num = (a.sum().contiguous()+b.sum().contiguous()).item()
print(num)
# confirm schedule cache doesn't grow
start_len_schedule_cache = len(schedule_cache)
for _ in range(3):
num = (a.sum().contiguous()+b.sum().contiguous()).item()
print(num)
self.assertEqual(len(schedule_cache), start_len_schedule_cache)
if __name__ == "__main__":
unittest.main()

View file

@ -3,6 +3,7 @@ from typing import cast
from dataclasses import dataclass, field, replace
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
from tinygrad.uop.ops import PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
@ -113,24 +114,77 @@ from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map
def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
if (ret:=ctx.get(b, None)) is None:
if b.op is Ops.BUFFER:
ctx[b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx)), b.src[1]))
else:
# TODO: flip args in CONST
assert b.op is Ops.CONST
ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
return ret
pm_pre_sched_cache = PatternMatcher([
# replace input buffers
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# remove unique consts
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
])
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
if (ret:=ctx.get(b, None)) is None:
assert b.op is Ops.BUFFER
# if it's not in the cache, create a new buffer
ctx[b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
return ret
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
])
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
# verify Tensors match the spec
if SPEC: type_verify(big_sink, tensor_spec)
# replace all UNIQUE buffers with LUNIQUE
input_buffers: dict[UOp, UOp] = {}
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
sched_cache_key = big_sink_cache.key
# tensor map is what we return
tensor_map: dict[UOp, UOp] = {}
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
tensor_map |= get_multi_map(big_sink)
big_sink = big_sink.substitute(tensor_map, name="Apply Multi Map")
big_sink = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink.src]))
# hack to preserve metadata
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx={}, name="preserve metadata")
tensor_map |= get_rangeify_map(big_sink)
big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map")
# tensor map is what we return
tensor_map: dict[UOp, UOp] = {}
if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
tensor_map |= get_multi_map(big_sink_cache)
big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
tensor_map |= get_rangeify_map(big_sink_cache)
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
# save in schedule cache
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
schedule_cache[sched_cache_key] = (big_sink, tensor_map_sink)
else:
# schedule cache hit
del big_sink_cache
big_sink, tensor_map_sink = sc_ret
# replace all the LUNIQUEs with UNIQUEs
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
big_sink = graph_rewrite(big_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for sched cache")
tm_src = graph_rewrite(tensor_map_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for tensor map").src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
# create the schedule
schedule, var_vals = create_schedule_with_vars(big_sink)
@ -140,5 +194,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)")
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
return tensor_map, schedule, var_vals

View file

@ -5,7 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@ -297,7 +297,7 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
# NOTE: this has been fixed up a bit
def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=True):
def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
#assert isinstance(x.tag, Flat), "bufferize must be flat"
size = prod(x.shape)
rngs = sorted(idx.ranges, key=lambda x: x.arg)
@ -323,7 +323,7 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
assert sdtype.addrspace == AddrSpace.GLOBAL
outer_range = x.src[0].src[1]
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
# NOTE: this has the same number as the outer range, we need string ranges!
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
@ -333,13 +333,13 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs)
return buf.after(do_store)
if allow_locals:
# handle locals
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(unwrap(ctx)))
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(ctx))
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store.barrier())
@ -356,7 +356,7 @@ def flatten_bufferize(x:UOp):
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda x, idx: bufferize_to_store(None, x, idx, allow_locals=False)),
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
# move RESHAPEs through MSELECT/MSTACK
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
@ -503,7 +503,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}")
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
return kernel
split_kernels = PatternMatcher([
@ -517,7 +517,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
return x.replace(tag=(len(ctx)-1,))
add_tags = PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
])
@ -560,7 +560,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
# bufferize -> store
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, bottom_up=True, name="bufferize to store")
lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign