Compare commits

...

27 commits

Author SHA1 Message Date
George Hotz
d1223922b1 fixed and test is real 2025-12-04 16:52:11 -08:00
George Hotz
05c4b18f91
Merge branch 'master' into sched_cache 2025-12-04 16:46:23 -08:00
George Hotz
f58b3afeb2
Merge branch 'master' into sched_cache 2025-12-03 16:12:44 -08:00
George Hotz
e0a805765e full jit 2025-12-03 16:08:34 -08:00
George Hotz
7c66e44454 fix JIT in examples/gradaccum_mnist.py 2025-12-03 16:00:28 -08:00
George Hotz
e75e391ad4
Merge branch 'master' into sched_cache 2025-12-03 15:41:31 -08:00
George Hotz
8c69e26d22 metadata is best effort 2025-12-03 15:22:58 -08:00
George Hotz
74fb405cc9 reenable the actual schedule cache 2025-12-03 15:03:42 -08:00
George Hotz
bf5de6ba5f delete abstractions2 2025-12-03 15:02:20 -08:00
George Hotz
183b3ced03 fix process replay 2025-12-03 14:56:28 -08:00
George Hotz
2280dae504 src[0].op 2025-12-03 14:50:46 -08:00
George Hotz
9ba612f0b4
Merge branch 'master' into sched_cache 2025-12-03 14:50:29 -08:00
George Hotz
32794853db why is that broken? 2025-12-03 14:44:41 -08:00
George Hotz
4a72a49082
Merge branch 'master' into sched_cache 2025-12-03 14:34:49 -08:00
George Hotz
9e6f8c823d always miss 2025-12-03 14:22:26 -08:00
George Hotz
4459a88a54 fix spec 2025-12-03 14:19:07 -08:00
George Hotz
9cdda8913f put that there 2025-12-03 14:15:13 -08:00
George Hotz
e644d59f9f oops, fix cache 2025-12-03 14:07:04 -08:00
George Hotz
37a930591f preserve metadata 2025-12-03 14:04:20 -08:00
George Hotz
723179dfd6
Merge branch 'master' into sched_cache 2025-12-03 13:43:58 -08:00
George Hotz
81bafb1af3
Merge branch 'master' into sched_cache 2025-12-02 19:59:48 -08:00
George Hotz
ed89217ef2 fix tests 2025-12-02 17:14:06 -08:00
George Hotz
79f2cfcb96 schedule cache cleanup 2025-12-02 16:59:32 -08:00
George Hotz
add768aab0 schedule cache works 2025-12-02 16:40:30 -08:00
George Hotz
2d6cf839d5 local unique 2025-12-02 15:45:56 -08:00
George Hotz
b4c3a6977e
Merge branch 'master' into sched_cache 2025-12-02 12:54:14 -08:00
George Hotz
7f7aa0a7f8 start work on schedule cache 2025-12-02 07:44:10 -08:00
5 changed files with 104 additions and 20 deletions

View file

@ -1,5 +1,6 @@
import gc import gc
from tinygrad import Tensor, UOp, Device, nn from tinygrad import Tensor, UOp, Device, nn
from tinygrad.engine.schedule import schedule_cache
from tinygrad.engine.realize import method_cache, get_program from tinygrad.engine.realize import method_cache, get_program
from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape from tinygrad.schedule.indexing import apply_movement_op, _apply_reshape
from tinygrad.uop.divandmod import fold_divmod_general from tinygrad.uop.divandmod import fold_divmod_general
@ -68,6 +69,7 @@ if __name__ == "__main__":
t() t()
# these caches will keep uops alive # these caches will keep uops alive
schedule_cache.clear()
method_cache.clear() method_cache.clear()
apply_movement_op.cache_clear() apply_movement_op.cache_clear()
_apply_reshape.cache_clear() _apply_reshape.cache_clear()

View file

@ -829,6 +829,7 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(len(si.metadata), 3) self.assertEqual(len(si.metadata), 3)
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"}) self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
@unittest.skip("metadata is no longer promised to be exact with schedulecache")
def test_complex_backward(self): def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize() x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize() y = Tensor.rand(3, requires_grad=True).realize()

View file

@ -0,0 +1,24 @@
import unittest
from tinygrad import Tensor
from tinygrad.engine.schedule import schedule_cache
class TestScheduleCache(unittest.TestCase):
def test_simple(self):
a = Tensor.ones(10).contiguous()
b = Tensor.ones(10).contiguous()
Tensor.realize(a, b)
# warm up
for _ in range(2):
num = (a.sum().contiguous()+b.sum().contiguous()).item()
print(num)
# confirm schedule cache doesn't grow
start_len_schedule_cache = len(schedule_cache)
for _ in range(3):
num = (a.sum().contiguous()+b.sum().contiguous()).item()
print(num)
self.assertEqual(len(schedule_cache), start_len_schedule_cache)
if __name__ == "__main__":
unittest.main()

View file

@ -3,6 +3,7 @@ from typing import cast
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from collections import deque from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
from tinygrad.uop.ops import PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
@ -113,24 +114,77 @@ from tinygrad.engine.memory import memory_planner
from tinygrad.schedule.rangeify import get_rangeify_map from tinygrad.schedule.rangeify import get_rangeify_map
from tinygrad.schedule.multi import get_multi_map from tinygrad.schedule.multi import get_multi_map
def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
if (ret:=ctx.get(b, None)) is None:
if b.op is Ops.BUFFER:
ctx[b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx)), b.src[1]))
else:
# TODO: flip args in CONST
assert b.op is Ops.CONST
ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
return ret
pm_pre_sched_cache = PatternMatcher([
# replace input buffers
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
# remove unique consts
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
])
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
if (ret:=ctx.get(b, None)) is None:
assert b.op is Ops.BUFFER
# if it's not in the cache, create a new buffer
ctx[b] = ret = UOp.new_buffer(b.device, b.arg, b.dtype)
return ret
pm_post_sched_cache = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.LUNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer_back),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.LUNIQUE)), name="b"), replace_input_buffer_back),
])
schedule_cache: dict[bytes, tuple[UOp, UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}") @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]: def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
# big_sink srcs are all the Tensors # big_sink srcs are all the Tensors
st = time.perf_counter() st = time.perf_counter()
# verify Tensors match the spec # replace all UNIQUE buffers with LUNIQUE
if SPEC: type_verify(big_sink, tensor_spec) input_buffers: dict[UOp, UOp] = {}
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
sched_cache_key = big_sink_cache.key
# tensor map is what we return if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
tensor_map: dict[UOp, UOp] = {} # verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
if any(isinstance(x._device, tuple) for x in big_sink.toposort()): # hack to preserve metadata
tensor_map |= get_multi_map(big_sink) graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx={}, name="preserve metadata")
big_sink = big_sink.substitute(tensor_map, name="Apply Multi Map")
big_sink = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink.src]))
tensor_map |= get_rangeify_map(big_sink) # tensor map is what we return
big_sink = big_sink.substitute(tensor_map, name="Apply Kernelize Map") tensor_map: dict[UOp, UOp] = {}
if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
tensor_map |= get_multi_map(big_sink_cache)
big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
tensor_map |= get_rangeify_map(big_sink_cache)
big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
# save in schedule cache
tensor_map_sink = UOp.sink(*flatten([(k,v) for k,v in tensor_map.items()]))
schedule_cache[sched_cache_key] = (big_sink, tensor_map_sink)
else:
# schedule cache hit
del big_sink_cache
big_sink, tensor_map_sink = sc_ret
# replace all the LUNIQUEs with UNIQUEs
input_buffers_reverse = {v:k for k,v in input_buffers.items()}
big_sink = graph_rewrite(big_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for sched cache")
tm_src = graph_rewrite(tensor_map_sink, pm_post_sched_cache, ctx=input_buffers_reverse, name="unrewrite for tensor map").src
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
# create the schedule # create the schedule
schedule, var_vals = create_schedule_with_vars(big_sink) schedule, var_vals = create_schedule_with_vars(big_sink)
@ -140,5 +194,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER} tensor_map |= {u:u.buf_uop for u in big_sink.toposort() if u.op is Ops.AFTER}
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
print(f"scheduled {len(schedule)} kernels in {(time.perf_counter()-st)*1000:.2f} ms ({len(UOpMetaClass.ucache)} uops in cache)") print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache)} uops in cache")
return tensor_map, schedule, var_vals return tensor_map, schedule, var_vals

View file

@ -5,7 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
@ -297,7 +297,7 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier) # BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
# NOTE: this has been fixed up a bit # NOTE: this has been fixed up a bit
def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=True): def bufferize_to_store(ctx:itertools.count, x:UOp, idx:UOp, allow_locals=True):
#assert isinstance(x.tag, Flat), "bufferize must be flat" #assert isinstance(x.tag, Flat), "bufferize must be flat"
size = prod(x.shape) size = prod(x.shape)
rngs = sorted(idx.ranges, key=lambda x: x.arg) rngs = sorted(idx.ranges, key=lambda x: x.arg)
@ -323,7 +323,7 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER: if x.src[0].op is Ops.REDUCE and len(x.src[0].src) == 2 and x.src[0].src[1].arg[-1] == AxisType.OUTER:
assert sdtype.addrspace == AddrSpace.GLOBAL assert sdtype.addrspace == AddrSpace.GLOBAL
outer_range = x.src[0].src[1] outer_range = x.src[0].src[1]
buf = UOp.new_buffer(x.arg.device, size, x.dtype) buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
# NOTE: this has the same number as the outer range, we need string ranges! # NOTE: this has the same number as the outer range, we need string ranges!
zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,)) zero_range = outer_range.replace(src=(UOp.const(dtypes.index, size),), arg=outer_range.arg[:-1]+(AxisType.LOOP,))
buf = buf.after(buf.index(zero_range).store(0).end(zero_range)) buf = buf.after(buf.index(zero_range).store(0).end(zero_range))
@ -333,13 +333,13 @@ def bufferize_to_store(ctx:itertools.count|None, x:UOp, idx:UOp, allow_locals=Tr
# NOTE: the DEFINE_LOCAL needs to be disambiguated here # NOTE: the DEFINE_LOCAL needs to be disambiguated here
if sdtype.addrspace == AddrSpace.GLOBAL: if sdtype.addrspace == AddrSpace.GLOBAL:
buf = UOp.new_buffer(x.arg.device, size, x.dtype) buf = UOp(Ops.BUFFER, x.dtype, (UOp(Ops.LUNIQUE, arg=next(ctx)), UOp(Ops.DEVICE, arg=x.arg.device)), size)
do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs) do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs)
return buf.after(do_store) return buf.after(do_store)
if allow_locals: if allow_locals:
# handle locals # handle locals
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(unwrap(ctx))) buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=next(ctx))
do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs) do_store = buf.broadcast(x.src[1].dtype.count).index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
return buf.after(do_store.barrier()) return buf.after(do_store.barrier())
@ -356,7 +356,7 @@ def flatten_bufferize(x:UOp):
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)]) pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([ pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda x, idx: bufferize_to_store(None, x, idx, allow_locals=False)), (UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda ctx,x,idx: bufferize_to_store(ctx, x, idx, allow_locals=False)),
# move RESHAPEs through MSELECT/MSTACK # move RESHAPEs through MSELECT/MSTACK
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"), (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
@ -503,7 +503,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]) kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1])
kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg) kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]): if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src if x.op is not Ops.BIND]):
raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in kernel.src)}") raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src)}")
return kernel return kernel
split_kernels = PatternMatcher([ split_kernels = PatternMatcher([
@ -517,7 +517,7 @@ def tag_uop(ctx:list[UOp], x:UOp):
return x.replace(tag=(len(ctx)-1,)) return x.replace(tag=(len(ctx)-1,))
add_tags = PatternMatcher([ add_tags = PatternMatcher([
# don't tag BUFFERs, they are global # don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END, (UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.KERNEL, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop), Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)), (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.BUFFER for s in x.src) else tag_uop(ctx, x)),
]) ])
@ -560,7 +560,8 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify") if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
# bufferize -> store # bufferize -> store
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, bottom_up=True, name="bufferize to store") lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
tsink = graph_rewrite(tsink, pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True, name="bufferize to store")
tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels") tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign