remove linear_to_schedule from tests (#15912)

* remove linear_to_schedule from tests

* x
This commit is contained in:
nimlgen 2026-04-24 20:02:10 +03:00 committed by GitHub
commit f2751955cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 53 additions and 57 deletions

View file

@ -73,9 +73,11 @@ if __name__ == "__main__":
A, B = Tensor.normal(M, K, std=1e-1, dtype=dtypes.float16).realize(), Tensor.normal(K, N, std=1e-1, dtype=dtypes.float16).realize()
C = A.matmul(B)
from tinygrad.schedule import linear_to_schedule
from tinygrad.uop.ops import Ops
linear, var_vals = C.linear_with_vars()
si = linear_to_schedule(linear)[-1]
last_call = linear.src[-1]
ast = last_call.src[0]
bufs = [s.buffer for s in last_call.src[1:] if s.op is not Ops.BIND]
src = compiled.asm["ptx"]
# specify the shared memory here so we don't need to do it dynamically
@ -89,7 +91,7 @@ if __name__ == "__main__":
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
ei = ExecItem(ast, [x.ensure_allocated() for x in bufs], last_call.arg.metadata, prg=CompiledRunner(prg))
tflops = []
for i in range(5):
tm = ei.run(wait=True)

View file

@ -3,7 +3,6 @@ import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
from tinygrad.helpers import Context, getenv, DEV
from tinygrad.engine.realize import run_linear
from tinygrad.schedule import linear_to_schedule
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.schedule import ExecItem
from tinygrad.renderer import Estimates
@ -56,7 +55,7 @@ class TestIndexing(unittest.TestCase):
GlobalCounters.reset()
out = ((Tensor.arange(1,16385)-1)*needle).sum()
linear, var_vals = out.linear_with_vars()
self.assertEqual(len(linear_to_schedule(linear)), 1)
self.assertEqual(len(linear.src), 1)
run_linear(linear, var_vals)
self.assertEqual(out.item(), 1337)
@ -73,7 +72,7 @@ class TestIndexing(unittest.TestCase):
full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, DDIM, DSET, 1))
X = full.sum(axis=(2,3))
linear, var_vals = X.linear_with_vars()
self.assertEqual(len(linear_to_schedule(linear)), 1)
self.assertEqual(len(linear.src), 1)
run_linear(linear, var_vals)
assert GlobalCounters.global_ops < 4*DSET, f"too many ops {GlobalCounters.global_ops}"
np.testing.assert_allclose(real_index, X.numpy())
@ -99,7 +98,7 @@ class TestIndexing(unittest.TestCase):
X = dataset[idxs]
assert X.shape == (4,DDIM)
linear, var_vals = X.linear_with_vars()
self.assertEqual(len(linear_to_schedule(linear)), 1)
self.assertEqual(len(linear.src), 1)
run_linear(linear, var_vals)
assert GlobalCounters.global_ops < 4*DSET, f"too many ops {GlobalCounters.global_ops}"
np.testing.assert_allclose(real_index, X.numpy())
@ -114,7 +113,7 @@ class TestIndexing(unittest.TestCase):
X = dataset[idxs]
assert X.shape == (4,DDIM)
linear, var_vals = X.linear_with_vars()
self.assertEqual(len(linear_to_schedule(linear)), 1)
self.assertEqual(len(linear.src), 1)
run_linear(linear, var_vals)
assert GlobalCounters.global_ops < 4*DSET, f"too many ops {GlobalCounters.global_ops} != {4*DSET}"
np.testing.assert_allclose(real_index, X.numpy())

View file

@ -7,7 +7,6 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_linear, CompiledRunner, get_program
from tinygrad.schedule import linear_to_schedule
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, DEV
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.renderer.ptx import PTXRenderer
@ -288,11 +287,10 @@ class TestLinearizer(unittest.TestCase):
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(b.where(2, a))
linear, var_vals = a.linear_with_vars()
sched_copy = linear_to_schedule(linear)
assert len(sched_copy) == 1
assert len(linear.src) == 1
run_linear(linear, var_vals)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
program = get_program(replace_opts(sched_copy[-1].ast, []), renderer=Device[Device.DEFAULT].renderer)
program = get_program(replace_opts(linear.src[-1].src[0], []), renderer=Device[Device.DEFAULT].renderer)
assert not any(u.op == Ops.WHERE for u in program.uops), "found where where where should be folded"
def test_phi_simplification(self):
@ -390,15 +388,17 @@ class TestLinearizer(unittest.TestCase):
def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
if isinstance(r, Tensor): r = [r]
linear, var_vals = Tensor.linear_with_vars(*r)
s = linear_to_schedule(linear)
run_linear(UOp(Ops.LINEAR, src=linear.src[:-1]), var_vals) # run all kernels except the last one
assert s[-1].ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {s[-1]}"
# now all input buffers in s[-1] should be realized
last_call = linear.src[-1]
ast = last_call.src[0]
assert ast.op is Ops.SINK, f"helper_realized_ast expects a SINK {last_call}"
last_bufs = [s.buffer for s in last_call.src[1:] if s.op is not Ops.BIND]
# now all input buffers in last_call should be realized
# create fresh buffers for the outputs
bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
bufs = [Buffer(x.device, x.size, x.dtype).allocate() if i < len(ast.src) else x for i,x in enumerate(last_bufs)]
# ensure buffers are allocated
for b in bufs: b.ensure_allocated()
return s[-1].ast, bufs
return ast, bufs
def helper_linearizer_ast(ast:UOp, inputs:list[Tensor], *args, **kwargs):
assert isinstance(ast, UOp), "ast must be UOp"

View file

@ -5,7 +5,6 @@ from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import CompiledRunner, run_linear
from tinygrad.schedule import linear_to_schedule
import numpy as np
from hypothesis import given, strategies as strat, settings
from test.helpers import not_support_multi_device, needs_second_gpu, slow, call_is_graph
@ -194,9 +193,9 @@ class TestMultiTensor(unittest.TestCase):
for i in range(2):
xt = X[i*2:i*2+2].contiguous()
linear, var_vals = xt.linear_with_vars()
#kernels = [s for s in linear_to_schedule(linear) if s.ast.op is Ops.SINK]
#kernels = [call for call in linear.src if call.src[0].op is Ops.SINK]
#self.assertEqual(len(kernels), 1)
#self.assertEqual(kernels[0].bufs[0].device, devices_2[i])
#self.assertEqual(kernels[0].src[1].buffer.device, devices_2[i])
run_linear(linear, var_vals)
np.testing.assert_equal(xt.numpy(), X_np[i*2:i*2+2])
@ -809,7 +808,7 @@ class TestMultiTensor(unittest.TestCase):
t = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
out = Tensor.full_like(t, 2)[:, :8]
linear, var_vals = out.linear_with_vars()
self.assertEqual(len(linear_to_schedule(linear)), 0)
self.assertEqual(len(linear.src), 0)
run_linear(linear, var_vals)
self.assertEqual(out.tolist(), [[2]*8]*16)
@ -1159,12 +1158,12 @@ class TestMultiBufferView(unittest.TestCase):
def setUp(self): pass
def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
"""Apply view_fn to both, verify zero compiled kernels and matching values."""
b_ref = view_fn(a_ref)
b_multi = view_fn(a_multi).contiguous()
linear, var_vals = b_multi.linear_with_vars()
compiled = [si for si in linear_to_schedule(linear) if isinstance(si.prg, CompiledRunner)]
self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
if all(hasattr(Device[d].allocator, "_offset") for d in b_multi.device):
compiled = [call for call in linear.src if call.src[0].op is Ops.SINK]
self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
run_linear(linear, var_vals)
np.testing.assert_equal(b_multi.numpy(), b_ref.numpy())
@ -1192,11 +1191,13 @@ class TestMultiBufferView(unittest.TestCase):
def test_4_devices(self):
ref = Tensor.arange(8*12).reshape(8, 12).contiguous().realize()
a = Tensor.arange(8*12).reshape(8, 12).contiguous().shard(devices_4, axis=1).realize()
linear, var_vals = a[5].contiguous().linear_with_vars()
compiled = [si for si in linear_to_schedule(linear) if isinstance(si.prg, CompiledRunner)]
self.assertEqual(len(compiled), 0)
out = a[5].contiguous()
linear, var_vals = out.linear_with_vars()
if all(hasattr(Device[d].allocator, "_offset") for d in out.device):
compiled = [call for call in linear.src if call.src[0].op is Ops.SINK]
self.assertEqual(len(compiled), 0)
run_linear(linear, var_vals)
np.testing.assert_equal(a[5].contiguous().numpy(), ref[5].numpy())
np.testing.assert_equal(out.numpy(), ref[5].numpy())
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiFromUnrenderable(unittest.TestCase):

View file

@ -9,7 +9,6 @@ from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.realize import run_linear
from tinygrad.schedule import linear_to_schedule
from test.helpers import not_support_multi_device, needs_second_gpu, slow
@slow
@ -433,7 +432,7 @@ class TestNN(unittest.TestCase):
[12, 19, 8, 1]])
result = layer(a)
linear, var_vals = result.linear_with_vars()
self.assertEqual(len([item for item in linear_to_schedule(linear) if item.ast.op is Ops.SINK]), kcount,
self.assertEqual(len([call for call in linear.src if call.src[0].op is Ops.SINK]), kcount,
"first run realizes weight and embedding")
run_linear(linear, var_vals)
@ -442,7 +441,7 @@ class TestNN(unittest.TestCase):
[7, 8, 9]])
result = layer(b)
linear, var_vals = result.linear_with_vars()
self.assertEqual(1, len([item for item in linear_to_schedule(linear) if item.ast.op is Ops.SINK]),
self.assertEqual(1, len([call for call in linear.src if call.src[0].op is Ops.SINK]),
"second run realizes embedding only")
run_linear(linear, var_vals)
print(f"Embedding used {GlobalCounters.global_ops} ops")

View file

@ -12,8 +12,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Ops, UPat
from tinygrad.helpers import CI, DEBUG, OSX, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.realize import CompiledRunner, run_linear
from tinygrad.schedule import linear_to_schedule
from tinygrad.engine.realize import CompiledRunner, compile_linear, run_linear
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@ -24,17 +23,17 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
else:
assert isinstance(t, UOp), f"can't schedule {t}"
linear, var_vals = Tensor(t).linear_with_vars()
# test lowering all the ExecItems
sched = linear_to_schedule(linear)
for si in sched: si.lower()
kernel_cnt = len([si for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink])
kernel_cnt = sum((len(call.device) if isinstance(call.device, tuple) else 1)
for call in linear.src if call.src[0].op is Ops.SINK or not filter_sink)
if kernel_cnt != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
if DEBUG >= 3:
for i,s in enumerate(sched):
for i,call in enumerate(linear.src):
print("kernel", i+1)
print(s.ast)
print(call.src[0])
raise KernelCountException(f"{kernel_cnt} != {allowed}")
# test compiling the linear
compile_linear(linear)
return linear, var_vals
def _realize_weights(m):
@ -50,9 +49,8 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float):
ret = Tensor.conv2d(img, w).relu().mean().backward()
dtypes.default_float = old_default_float
linear, var_vals = Tensor.linear_with_vars(ret, img.grad, w.grad)
s = linear_to_schedule(linear)
run_linear(linear, var_vals)
cnt = len([si for si in s if si.ast.op is Ops.SINK])
cnt = len([call for call in linear.src if call.src[0].op is Ops.SINK])
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
if getenv("CHECK", 1):
import torch
@ -74,7 +72,7 @@ class TestSchedule(unittest.TestCase):
x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32)
t = x.avg_pool2d(padding=1)
linear, var_vals = t.linear_with_vars()
self.assertEqual(len(linear_to_schedule(linear)), kcount)
self.assertEqual(len(linear.src), kcount)
run_linear(linear, var_vals)
import torch
torch_out = torch.nn.functional.avg_pool2d(torch.arange(25).reshape(1,1,5,5).float(), kernel_size=(2,2), padding=1).numpy()
@ -1055,7 +1053,7 @@ class TestSchedule(unittest.TestCase):
expected = (a+a2).tolist()
a.assign(a+a2)
linear, var_vals = a.linear_with_vars()
kcount = len(linear_to_schedule(linear))
kcount = len(linear.src)
run_linear(linear, var_vals)
self.assertListEqual(a.tolist(), expected)
self.assertEqual(kcount, expected_kcount)
@ -1356,7 +1354,7 @@ class TestCopyFolding(unittest.TestCase):
a = Tensor.ones(4).contiguous().realize().uop.buf_uop
t = Tensor(a.copy_to_device(a.device))
linear, var_vals = t.linear_with_vars()
assert len([s for s in linear_to_schedule(linear) if s.ast.op is Ops.COPY]) == 0
assert len([call for call in linear.src if call.src[0].op is Ops.COPY]) == 0
run_linear(linear, var_vals)
assert t.uop.is_realized, f"didn't realize Tensor {t}"
self.assertListEqual(t.tolist(), [1.,1.,1.,1.])

View file

@ -7,7 +7,6 @@ from tinygrad import GlobalCounters, Tensor, Device
from tinygrad.helpers import getenv
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import capturing, run_linear
from tinygrad.schedule import linear_to_schedule
from tinygrad.tensor import _to_np_dtype
class CLCache:
@ -15,7 +14,7 @@ class CLCache:
self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
self.count = 0
def add_linear(self, linear, var_vals):
self.count += len(linear_to_schedule(linear))
self.count += len(linear.src)
run_linear(linear, var_vals)
def __enter__(self):
if self.preclear:

View file

@ -3,8 +3,7 @@ import gc, unittest, time
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, KernelInfo
from tinygrad.helpers import DEBUG, GlobalCounters, Context
from tinygrad.engine.realize import CompiledRunner, run_linear
from tinygrad.schedule import linear_to_schedule
from tinygrad.engine.realize import compile_linear, run_linear
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@ -15,17 +14,17 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
else:
assert isinstance(t, UOp), f"can't schedule {t}"
linear, var_vals = Tensor(t).linear_with_vars()
# test lowering all the ExecItems
sched = linear_to_schedule(linear)
for si in sched: si.lower()
kernel_cnt = len([si for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink])
kernel_cnt = sum((len(call.device) if isinstance(call.device, tuple) else 1)
for call in linear.src if call.src[0].op is Ops.SINK or not filter_sink)
if kernel_cnt != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
if DEBUG >= 3:
for i,s in enumerate(sched):
for i,call in enumerate(linear.src):
print("kernel", i+1)
print(s.ast)
print(call.src[0])
raise KernelCountException(f"{kernel_cnt} != {allowed}")
# test compiling the linear
compile_linear(linear)
return linear, var_vals
def _realize_weights(m):

View file

@ -2,7 +2,6 @@ import unittest
from tinygrad import Tensor, dtypes
from tinygrad.tensor import _METADATA
from tinygrad.engine.realize import capturing
from tinygrad.schedule import linear_to_schedule
from tinygrad.helpers import Context
@unittest.skip("tensor metadata is no longer supported")
@ -99,8 +98,8 @@ class TestTensorMetadata(unittest.TestCase):
capturing.append(type("", (), {"add_linear": lambda _, linear, var_vals: linears.append(linear)})())
try: h.realize()
finally: capturing.clear()
items = [ei for linear in linears for ei in linear_to_schedule(linear)]
return any(m.name == name for ei in items for m in ei.metadata)
calls = [call for linear in linears for call in linear.src]
return any(m.name == name for call in calls for m in call.arg.metadata)
def test_metadata_survives_realize_pending_assign(self):
shared = Tensor.rand(4)