schedule() -> schedule_linear() in tests (batch 1) (#15915)

* schedule_with_vars -> linear_with_vars in tests

* tests batch 1

* batch 2

* estimate_uop

* simpler

* rm
This commit is contained in:
nimlgen 2026-04-24 23:40:53 +03:00 committed by GitHub
commit d3378010ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 100 additions and 87 deletions

View file

@ -16,9 +16,9 @@ class TestAttention(unittest.TestCase):
k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
attn = q.scaled_dot_product_attention(k, v)
sched = attn.schedule()
sched = attn.schedule_linear()
# attention has 4 kernels now
self.assertEqual(len(sched), 4)
self.assertEqual(len(sched.src), 4)
def test_apply_rope_jit_prune(self):
def rope_fn(x_in, pos): return apply_rope(x_in, pos)

View file

@ -3,11 +3,11 @@ from contextlib import redirect_stdout
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import OSX, DEV
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import get_program
from tinygrad.engine.realize import get_program, compile_linear
class TestCompileFailures(unittest.TestCase):
def compile(self, out:Tensor):
for si in out.schedule(): si.lower()
compile_linear(out.schedule_linear())
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
def test_interpolate_atari(self):
@ -21,8 +21,8 @@ class TestDisassembly(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT in ("CPU",) and DEV.renderer not in ("LLVM", "LVP") and OSX, "m series cpus support fp16 arithmetic")
def test_float16_alu(self):
c = Tensor([1], dtype=dtypes.float16) + Tensor([1], dtype=dtypes.float16)
s = c.schedule()[-1]
p = get_program(s.ast, Device[Device.DEFAULT].renderer)
s = c.schedule_linear().src[-1]
p = get_program(s.src[0], Device[Device.DEFAULT].renderer)
lib = Device[Device.DEFAULT].compiler.compile(p.src)
out = io.StringIO()
with redirect_stdout(out): Device[Device.DEFAULT].compiler.disassemble(lib)

View file

@ -7,8 +7,8 @@ import numpy as np
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = t.schedule()
asts = [s for s in schedule if s.ast.op is Ops.SINK]
linear = t.schedule_linear()
asts = [s for s in linear.src if s.src[0].op is Ops.SINK]
len(asts)
# NOT SUPPORTED ANYMORE
#assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"

View file

@ -60,7 +60,7 @@ class TestGC(unittest.TestCase):
init = bufs_allocated()
x = Tensor.ones(256).contiguous().realize()
y = Tensor.ones(5, 5).contiguous()
y.schedule()
y.schedule_linear()
del x
del y
self.assertEqual(bufs_allocated()-init, 0)

View file

@ -9,29 +9,29 @@ class TestLinearizerRewrite(unittest.TestCase):
t = Tensor.ones((64,64), device="NULL").contiguous().realize()
out = (t*2).sum(axis=1)
with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
si = out.schedule()[-1]
si = out.schedule_linear().src[-1]
opts_to_apply = []
opts_to_apply.append(Opt(OptOps.UPCAST, 0, 4))
opts_to_apply.append(Opt(OptOps.UNROLL, 0, 4))
ast = si.ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
ast = si.src[0].replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
prg = get_program(ast, Device["CPU"].renderer)
print(prg.src)
def test_arange(self):
out = Tensor.arange(32, device="NULL")
with Context(SPLIT_REDUCEOP=0, DEVECTORIZE=0):
si = out.schedule()[-1]
si = out.schedule_linear().src[-1]
opts_to_apply = []
opts_to_apply.append(Opt(OptOps.UPCAST, 0, 4))
ast = si.ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
ast = si.src[0].replace(arg=KernelInfo(opts_to_apply=tuple(opts_to_apply)))
prg = get_program(ast, Device["CPU"].renderer)
print(prg.src)
def test_kernel_info(self):
out = Tensor.arange(4, device="NULL")
si = out.schedule()[-1]
si = out.schedule_linear().src[-1]
ast = si.ast.replace(arg=KernelInfo(opts_to_apply=()))
ast = si.src[0].replace(arg=KernelInfo(opts_to_apply=()))
prg = get_program(ast, Device["CPU"].renderer)
assert prg.applied_opts == (), f"expected no opts, got {prg}"

View file

@ -9,7 +9,7 @@ N = 16
class TestProcessReplay(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.ast = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
cls.ast = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule_linear().src[-1].src[0]
cls.renderer = Device[Device.DEFAULT].renderer
def test_replay_no_opts(self):
@ -35,9 +35,9 @@ class TestProcessReplay(unittest.TestCase):
def test_beam(self):
with Context(BEAM=1):
si = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1]
p = do_to_program(si.ast, self.renderer)
good, compare, _ = replay_to_program(p, si.ast, self.renderer)
ast = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule_linear().src[-1].src[0]
p = do_to_program(ast, self.renderer)
good, compare, _ = replay_to_program(p, ast, self.renderer)
self.assertEqual(good, compare)
if __name__ == '__main__':

View file

@ -87,7 +87,7 @@ class TestBufferUOp(unittest.TestCase):
# unused variable should not appear in var_vals even when there's other work
a = Tensor(UOp.variable("unused", 0, 10).bind(1))
b = Tensor.empty(3) + 1
_, var_vals = Tensor.schedule_with_vars(a, b)
_, var_vals = Tensor.linear_with_vars(a, b)
self.assertEqual(var_vals, {})
self.assertIsNone(a.uop.base.realized)
@ -208,8 +208,8 @@ class TestSchedule(unittest.TestCase):
t = Tensor.zeros((3, 3)).contiguous().realize()
v = t[1] # view - is_realized but not has_buffer_identity
assert v.uop.is_realized
sched, _ = Tensor.schedule_with_vars(v)
self.assertEqual(len(sched), 0)
linear, _ = Tensor.linear_with_vars(v)
self.assertEqual(len(linear.src), 0)
# NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated.
def test_childless_empty_never_allocates(self):

View file

@ -4,7 +4,7 @@ from tinygrad.helpers import cpu_events
from tinygrad.schedule import schedule_cache
def schedule_one():
Tensor([1]).schedule()
Tensor([1]).schedule_linear()
class TestScheduleCache(unittest.TestCase):
def test_bound_variable_var_vals(self):
@ -12,7 +12,7 @@ class TestScheduleCache(unittest.TestCase):
x = Tensor.ones(10).contiguous().realize()
t = x + Tensor(v.bind(42))
_, var_vals = t.schedule_with_vars()
_, var_vals = t.linear_with_vars()
self.assertEqual(var_vals, {'pos': 42})
def test_disable_schedule_cache(self):

View file

@ -62,11 +62,12 @@ class TestIdxUpcast(unittest.TestCase):
for src in ast.src:
if (ret:=self._find_op(src, op)) is not None: return ret
def _schedule_render(self, a: Tensor):
schedule, _ = a.schedule_with_vars()
for s in schedule:
if s.ast.op is Ops.SINK:
renderer = Device[s.bufs[0].device].renderer
prg = get_program(s.ast, renderer)
linear, _ = a.linear_with_vars()
for si in linear.src:
ast = si.src[0]
if ast.op is Ops.SINK:
renderer = Device[si.src[1].buffer.device].renderer
prg = get_program(ast, renderer)
return prg.uops
def _assert(self, dtype: DType, a: Tensor):
@ -162,9 +163,9 @@ class TestRand(unittest.TestCase):
def test_rand_large_tensor(self):
# large tensor rand (num > uint32.max) should not crash in frontend
Tensor.manual_seed(0)
Tensor.rand(2**17, 2**17).schedule()
Tensor.rand(2**17, 2**17).schedule()
Tensor.rand(2**17, 2**17).schedule()
Tensor.rand(2**17, 2**17).schedule_linear()
Tensor.rand(2**17, 2**17).schedule_linear()
Tensor.rand(2**17, 2**17).schedule_linear()
class TestTensorConstLike(unittest.TestCase):
def test_const_like_shape(self):

View file

@ -16,7 +16,7 @@ class TestTensorMutates(unittest.TestCase):
pa = a.uop
pb = b.uop
pr = ret.uop
ret.schedule()
ret.schedule_linear()
self.assertIsNot(pa, a.uop)
self.assertIsNot(pb, b.uop)
self.assertIsNot(pr, ret.uop)

View file

@ -5,22 +5,22 @@ class TestLoadStore(unittest.TestCase):
def test_load_shape(self):
t = Tensor(bytes(16)).fs_load(1024)
assert t.shape == (1024,), t.shape
t.schedule()
t.schedule_linear()
def test_store_shape(self):
t = Tensor.zeros(1024).fs_store()
assert t.shape == (16,), t.shape
t.schedule()
t.schedule_linear()
def test_load_large_shape(self):
t = Tensor(bytes(16)).fs_load(10_000_000)
assert t.shape == (10_000_000,), t.shape
t.schedule()
t.schedule_linear()
def test_store_large_shape(self):
t = Tensor.zeros(10_000_000).fs_store()
assert t.shape == (16,), t.shape
t.schedule()
t.schedule_linear()
if __name__ == "__main__":
unittest.main()

View file

@ -228,7 +228,7 @@ class TestUOpMethod(unittest.TestCase):
a = UOp.variable("a", 1, 10)
uop_var = Tensor(a.bind(1))
st_var = Tensor.empty((2, 10))[:, :a.bind(1)]
_, var_vals = (uop_var+st_var).schedule_with_vars()
_, var_vals = (uop_var+st_var).linear_with_vars()
self.assertEqual(len(var_vals), 1)
self.assertEqual(list(var_vals)[0], a.expr)

View file

@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters, DEV
from tinygrad.engine.realize import get_program
from tinygrad.engine.realize import get_program, compile_linear, estimate_uop
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.uop.ops import Ops, UOp
@ -18,8 +18,8 @@ def flops_mem(uops, ignore_indexing=False):
# **************** new FlopCounter ****************
def get_stats(x:Tensor):
si = x.schedule()[-1].lower()
return si.prg.estimates.ops, si.prg.estimates.mem
est = estimate_uop(compile_linear(x.schedule_linear()).src[-1])
return est.ops, est.mem
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does extra load/store for packed types")
class TestMemoryCount(unittest.TestCase):
@ -165,8 +165,8 @@ N = 64
class TestStatsOptimized(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast
cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule_linear().src[-1].src[0]
cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule_linear().src[-1].src[0]
def check_gemm(self, p:ProgramSpec, extra_flops=0):
#p.uops.print()

View file

@ -24,8 +24,8 @@ class TestFloat4(unittest.TestCase):
b = Tensor.empty(2, 8).realize()
c = a + b
s = c.schedule()[0]
realized_ast = s.ast
s = c.schedule_linear().src[0]
realized_ast = s.src[0]
opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
program = get_program(replace_opts(realized_ast, opts_to_apply), renderer=Device[Device.DEFAULT].renderer)
@ -37,8 +37,8 @@ class TestFloat4(unittest.TestCase):
b = Tensor.empty(2, 8).realize()
c = a + b
s = c.schedule()[0]
uops = get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]),
s = c.schedule_linear().src[0]
uops = get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]),
renderer=Device[Device.DEFAULT].renderer).uops
assert TestFloat4.count_float4(uops) == (4, 2)
@ -49,8 +49,8 @@ class TestFloat4(unittest.TestCase):
b = Tensor.empty(2, size).realize()
c = a + b
s = c.schedule()[0]
return get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]),
s = c.schedule_linear().src[0]
return get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]),
renderer=Device[Device.DEFAULT].renderer).uops
sizes = [12, 8, 16]
@ -66,8 +66,8 @@ class TestFloat4(unittest.TestCase):
b = Tensor.empty(9).realize().shrink(((1, 9),))
c = a + b
s = c.schedule()[0]
realized_ast = s.ast
s = c.schedule_linear().src[0]
realized_ast = s.src[0]
opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
program = get_program(replace_opts(realized_ast, opts_to_apply), renderer=Device[Device.DEFAULT].renderer)
@ -79,8 +79,8 @@ class TestFloat4(unittest.TestCase):
b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
c = a + b
s = c.schedule()[0]
uops = get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]),
s = c.schedule_linear().src[0]
uops = get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]),
renderer=Device[Device.DEFAULT].renderer).uops
assert TestFloat4.count_float4(uops) == (0, 2)
@ -92,8 +92,8 @@ class TestFloat4(unittest.TestCase):
b = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
c = a + b
s = c.schedule()[0]
return get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]),
s = c.schedule_linear().src[0]
return get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]),
renderer=Device[Device.DEFAULT].renderer).uops
sizes = [13, 9, 17]
@ -111,8 +111,8 @@ class TestFloat4(unittest.TestCase):
# only the first and last conv dot products are aligned in a, and b is never aligned, so no
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = c.schedule()[0]
uops = get_program(replace_opts(s.ast, [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
s = c.schedule_linear().src[0]
uops = get_program(replace_opts(s.src[0], [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
assert TestFloat4.count_float4(uops) == (0, 0)
@ -125,8 +125,8 @@ class TestFloat4(unittest.TestCase):
# don't.
# UPDATE: now we do this fusion
s = c.schedule()[0]
uops = get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]),
s = c.schedule_linear().src[0]
uops = get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]),
renderer=Device[Device.DEFAULT].renderer).uops
assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}
@ -139,8 +139,8 @@ class TestFloat4(unittest.TestCase):
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
# since the top axis is not contiguous.
s = c.schedule()[0]
uops = get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
s = c.schedule_linear().src[0]
uops = get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
assert TestFloat4.count_float4(uops) == (0, 1)
@ -151,8 +151,8 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = c.schedule()[0]
uops = get_program(replace_opts(s.ast, [Opt(op=OptOps.UPCAST, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
s = c.schedule_linear().src[0]
uops = get_program(replace_opts(s.src[0], [Opt(op=OptOps.UPCAST, axis=0, arg=4)]), renderer=Device[Device.DEFAULT].renderer).uops
assert TestFloat4.count_float4(uops) == (1, 1)

View file

@ -24,8 +24,8 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
ensure_triggered:bool=True):
a, b = Tensor.rand(M, K, dtype=dtype_in), Tensor.rand(K, N, dtype=dtype_in)
r = a.matmul(b, dtype=dtype_out)
sched = r.schedule()
realized_ast = sched[-1].ast
sched = r.schedule_linear()
realized_ast = sched.src[-1].src[0]
opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))]
if ensure_triggered:
@ -76,7 +76,8 @@ class TestTensorCores(unittest.TestCase):
n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2]
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, dtype=tc.dtype_out)
prg = get_program(replace_opts(r.schedule()[-1].ast, [Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1))]), Device[Device.DEFAULT].renderer)
prg = get_program(replace_opts(r.schedule_linear().src[-1].src[0],
[Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1))]), Device[Device.DEFAULT].renderer)
if Device.DEFAULT == "CPU" and DEV.renderer == "LLVM":
assert "0x201000" in prg.src
elif Device.DEFAULT == "AMD" and DEV.renderer == "LLVM":

View file

@ -9,9 +9,9 @@ class TestRingAllReduce(unittest.TestCase):
N = 4
ds = tuple(f"CPU:{i}" for i in range(N))
t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
schedules = t.sum(0).schedule_with_vars()[0]
copies = [si for si in schedules if si.ast.op is Ops.COPY]
pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
linear = t.sum(0).linear_with_vars()[0]
copies = [si for si in linear.src if si.src[0].op is Ops.COPY]
pairs = [(c.src[1].buffer.device, c.src[2].buffer.device) for c in copies]
# N*(N-1) scatter reduce, and N*(N-1) allgather
self.assertEqual(len(pairs), N*(N-1)*2)
# copy topology forms a ring
@ -30,8 +30,8 @@ class TestAllreduceCast(unittest.TestCase):
ds = tuple(f"CPU:{i}" for i in range(2))
with Context(ALLREDUCE_CAST=allreduce_cast, RING=0, SCACHE=0):
t = Tensor.empty(4, 4, dtype=dtype).shard(ds, axis=0)
schedules = t.sum(0).schedule_with_vars()[0]
return {si.bufs[0].dtype.scalar() for si in schedules if si.ast.op is Ops.COPY}
linear = t.sum(0).linear_with_vars()[0]
return {si.src[1].buffer.dtype.scalar() for si in linear.src if si.src[0].op is Ops.COPY}
def test_allreduce_cast_bf16(self):
# with ALLREDUCE_CAST, allreduce copies stay in bfloat16 instead of promoting to float32

View file

@ -8,7 +8,7 @@ from tinygrad.engine.realize import get_program
@unittest.skipIf(Device.DEFAULT != "CPU", "only run on CPU")
class TestCPU(unittest.TestCase):
def test_arch_feats(self):
ast = (Tensor.empty(16) + Tensor.empty(16)).schedule()[-1].ast
ast = (Tensor.empty(16) + Tensor.empty(16)).schedule_linear().src[-1].src[0]
for ren in Device[Device.DEFAULT].renderers:
for arch, expect_vmov in [("x86_64,x86-64,avx", True), ("x86_64,x86-64,-avx", False)]:
with self.subTest(arch=arch):

View file

@ -1,11 +1,11 @@
from typing import TypeVar, Generic, Callable, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ, unwrap
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, pluralize, VIZ
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType, dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import capturing, CompiledRunner, Runner, Estimates, compile_linear, run_linear, get_runner, graph_cache
from tinygrad.engine.realize import capturing, CompiledRunner, Runner, Estimates, compile_linear, run_linear, get_runner, graph_cache, estimate_uop
from tinygrad.engine.realize import unwrap_multi, resolve_params
from tinygrad.schedule.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.nn.state import get_parameters
@ -131,11 +131,7 @@ class GraphRunner(Runner):
assert p.p.local_size is not None
self.launch_dims_base[j] = (tuple(p.p.global_size), tuple(p.p.local_size))
estimates = Estimates()
for (_, ast, bufs, _), pr in zip(self.calls, self.progs):
if ast.op in (Ops.SINK, Ops.PROGRAM): estimates += unwrap(pr).estimates
elif ast.op is Ops.COPY or (ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "encdec"):
estimates += Estimates(lds=bufs[0].nbytes, mem=bufs[0].nbytes)
estimates = sum((estimate_uop(call) for call in self.linear.src), Estimates())
# used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)

View file

@ -11,6 +11,16 @@ from tinygrad.codegen import get_program, to_program
# **************** Stat ****************
def estimate_uop(call:UOp) -> Estimates:
  """
  Return the `Estimates` for a single call UOp, chosen by the op of its first source.

  - SINK: the call is first passed through `pm_compile.rewrite`, then handled like the cases below.
  - PROGRAM: the estimates stored on the arg of the program's first source (empty `Estimates` if that value is falsy).
  - COPY, or CUSTOM_FUNCTION with arg "encdec": the byte size of `call.src[1]` (element count times dtype itemsize),
    reported as both `lds` and `mem`.
  - anything else: empty `Estimates`.
  """
  # NOTE(review): presumably pm_compile turns a SINK call into one whose first source is a PROGRAM -- confirm
  target = pm_compile.rewrite(call) if call.src[0].op is Ops.SINK else call
  body = target.src[0]
  if body.op is Ops.PROGRAM: return body.src[0].arg.estimates or Estimates()
  # guard: only COPY and the "encdec" custom function get a size-based estimate
  if not (body.op is Ops.COPY or (body.op is Ops.CUSTOM_FUNCTION and body.arg == "encdec")): return Estimates()
  # the size is taken from the first source after the ast
  first_arg = target.src[1]
  nbytes = prod(first_arg.shape) * first_arg.dtype.itemsize
  return Estimates(lds=nbytes, mem=nbytes)
def update_stats(display_name:str, device:str, estimates:Estimates, var_vals:dict[str, int], et:float|None, buf_count:int,
jit=False, metadata:tuple[Metadata, ...]=(), first_run=False):
GlobalCounters.kernel_count += 1
@ -208,7 +218,7 @@ def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in call.src[1:] if b.op is not Ops.BIND]
@contextlib.contextmanager
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimates:Estimates, bufs:list[Buffer], var_vals:dict[str, int],
def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, bufs:list[Buffer], var_vals:dict[str, int],
outputs=(0,), inputs=(1,), first_run=False):
if PROFILE: cpu_events.append(ProfilePointEvent(device, "exec", len(cpu_events), {"metadata": call.arg.metadata, "var_vals": var_vals,
"bufs": [b.trace_num for b in bufs], "name": display_name, "outputs": outputs, "inputs": inputs}))
@ -219,7 +229,7 @@ def track_stats(ctx:ExecContext, call:UOp, device:str, display_name:str, estimat
if DEBUG >= 2 and timing[0] is None:
Device[device].synchronize()
timing[0] = time.perf_counter() - st
update_stats(display_name, device, estimates, var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run)
update_stats(display_name, device, estimate_uop(call), var_vals, timing[0], len(bufs), jit=ctx.jit, metadata=call.arg.metadata, first_run=first_run)
def unwrap_multi(call:UOp, resolved:list[UOp]) -> Iterator[tuple[list[Buffer], dict[str, int]]]:
bufs = [b.buffer for b in resolved]
@ -232,7 +242,7 @@ def exec_view(ctx:ExecContext, call, ast):
resolved = resolve_params(call, ctx.input_uops)
bufs = [cast(Buffer, b.buffer) for b in resolved]
bv = bufs[1].view(resolved[0].arg, ast.dtype, ast.arg[1]*bufs[1].dtype.itemsize)
with track_stats(ctx, call, bv.device, colored(f"view {bv.nbytes:8d} @ {bv.offset:<10d}", "yellow"), Estimates(), [bv, bufs[1]], ctx.var_vals):
with track_stats(ctx, call, bv.device, colored(f"view {bv.nbytes:8d} @ {bv.offset:<10d}", "yellow"), [bv, bufs[1]], ctx.var_vals):
buffers[resolved[0]] = bv
def exec_copy(ctx:ExecContext, call, ast):
@ -240,7 +250,7 @@ def exec_copy(ctx:ExecContext, call, ast):
dest, src = bufs[0].ensure_allocated(), bufs[1].ensure_allocated()
xfer = hasattr(alc:=Device[dest.device].allocator,'_transfer') and alc.supports_transfer and dest.device.split(":")[0]==src.device.split(":")[0]
prg = (BufferXfer if xfer else BufferCopy)(dest.nbytes, dest.device, src.device)
with track_stats(ctx, call, dest.device, prg.display_name, Estimates(lds=dest.nbytes, mem=dest.nbytes), [dest, src], ctx.var_vals):
with track_stats(ctx, call, dest.device, prg.display_name, [dest, src], ctx.var_vals):
prg.copy(dest, src)
def exec_kernel(ctx:ExecContext, call, ast):
@ -252,7 +262,7 @@ def exec_kernel(ctx:ExecContext, call, ast):
if VALIDATE_WITH_CPU and ast.op is Ops.SINK:
cpu_bufs = [Buffer("CPU", b.size, b.dtype).ensure_allocated().copyin(b.ensure_allocated().as_memoryview()) for b in bufs]
with track_stats(ctx, call, prg.device, prg.display_name, prg.estimates, prg_bufs, var_vals,
with track_stats(ctx, call, prg.device, prg.display_name, prg_bufs, var_vals,
outputs=tuple(prg.p.outs), inputs=tuple(prg.p.ins), first_run=prg.first_run) as timing:
timing[0] = prg(prg_bufs, var_vals, wait=DEBUG >= 2)
prg.first_run = False
@ -266,8 +276,7 @@ def exec_kernel(ctx:ExecContext, call, ast):
def exec_encdec(ctx:ExecContext, call, ast):
bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(call, ctx.input_uops)]
shape, pos_var = tuple(s.arg for s in ast.src if s.op is Ops.CONST), ast.variables()[0].expr
with track_stats(ctx, call, bufs[0].device, colored(f"enc/dec {size_to_str(bufs[0].nbytes)}", "yellow"),
Estimates(lds=bufs[0].nbytes, mem=bufs[0].nbytes), bufs, ctx.var_vals):
with track_stats(ctx, call, bufs[0].device, colored(f"enc/dec {size_to_str(bufs[0].nbytes)}", "yellow"), bufs, ctx.var_vals):
bufs[0].allocator._encode_decode(bufs[0]._buf, bufs[1]._buf, bufs[2]._buf, [x._buf for x in bufs[3:]], shape, ctx.var_vals[pos_var])
graph_cache:weakref.WeakKeyDictionary[UOp, Runner] = weakref.WeakKeyDictionary()
@ -275,7 +284,7 @@ def exec_graph(ctx:ExecContext, call, cf):
bufs = flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in (u.buffer for u in resolve_params(call, ctx.input_uops))])
if (runner:=graph_cache.get(cf)) is None:
graph_cache[cf] = runner = Device[cf.device if isinstance(cf.device, str) else cf.device[0]].graph(cf, input_uops=ctx.input_uops)
with track_stats(ctx, call, runner.device, runner.display_name, runner.estimates, bufs, ctx.var_vals) as t:
with track_stats(ctx, call, runner.device, runner.display_name, bufs, ctx.var_vals) as t:
t[0] = runner(bufs, ctx.var_vals, wait=DEBUG >= 2, input_uops=ctx.input_uops) # type: ignore[call-arg]
# ctx is beam value

View file

@ -247,6 +247,12 @@ class Tensor(OpMixin):
assert len(var_vals) == 0
return schedule
def schedule_linear(self, *lst:Tensor) -> UOp:
  """
  Build the linear schedule for this Tensor together with any Tensors in `lst`.

  Calls `linear_with_vars` and returns only its first result; the second result (`var_vals`) must be empty,
  which is enforced with an assert.
  """
  result = self.linear_with_vars(*lst)
  linear, var_vals = result
  # this entry point has no way to hand var_vals back to the caller, so none may be present
  assert len(var_vals) == 0
  return linear
@disable_gc()
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""