validate_with_cpu as rewrite (#15938)

* validate_with_cpu as rewrite

* compil

* x

* linter

* moved

* fix
This commit is contained in:
nimlgen 2026-04-26 19:58:53 +03:00 committed by GitHub
commit 96165ff0d1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 81 additions and 29 deletions

View file

@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import Context, from_mv
from tinygrad.dtype import dtypes
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import run_linear
from tinygrad.engine.realize import run_linear, compile_linear
from tinygrad.uop.ops import UOp, Ops, buffers
from test.helpers import needs_second_gpu
@ -44,7 +44,7 @@ def get_buf_uop(buf:Buffer, cache:dict[Buffer,UOp]) -> UOp:
return cache[buf]
def make_graph(graph_cls, calls:list[UOp]):
linear = UOp(Ops.LINEAR, src=tuple(calls))
linear = compile_linear(UOp(Ops.LINEAR, src=tuple(calls)))
cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(linear,), arg="graph")
return graph_cls(cf, [])

View file

@ -0,0 +1,40 @@
import unittest
from tinygrad import Tensor, Context, Variable, Device
from test.helpers import needs_second_gpu
class TestValidateWithCPU(unittest.TestCase):
def setUp(self):
self.ctx = Context(VALIDATE_WITH_CPU=1)
self.ctx.__enter__()
def tearDown(self): self.ctx.__exit__(None, None, None)
def test_add(self): self.assertListEqual((Tensor([1.,2,3])+Tensor([4.,5,6])).tolist(), [5.0, 7.0, 9.0])
def test_mul(self): self.assertListEqual((Tensor([1.,2,3])*Tensor([4.,5,6])).tolist(), [4.0, 10.0, 18.0])
def test_sum(self): self.assertEqual(Tensor([1.,2,3,4]).sum().item(), 10.0)
def test_reduce_then_op(self): self.assertEqual((Tensor([1.,2,3,4]).sum() * 2).item(), 20.0)
def test_assign(self):
a = Tensor([1.,2,3]).realize()
a.assign(a + 1).realize()
self.assertListEqual(a.tolist(), [2.0, 3.0, 4.0])
def test_buffer_view(self):
self.assertListEqual((Tensor([1.,2,3,4,5,6,7,8])[2:6] + 1).tolist(), [4.0, 5.0, 6.0, 7.0])
def test_symbolic(self):
i = Variable('i', 1, 10)
ones = Tensor.ones(10).contiguous()
self.assertListEqual((ones[:i.bind(5)] + 1).contiguous()[:5].tolist(), [2.0]*5)
def test_multi_kernel(self):
a = (Tensor([1.,2,3]) + 1).contiguous()
b = (a * 2).contiguous()
self.assertListEqual((b - 1).tolist(), [3.0, 5.0, 7.0])
@needs_second_gpu
def test_sharded(self):
t = Tensor([1.,2,3,4]).shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"), axis=0)
self.assertListEqual((t + 1).tolist(), [2.0, 3.0, 4.0, 5.0])
if __name__ == "__main__":
unittest.main()

View file

@ -20,9 +20,9 @@ class TestHCQUnit(unittest.TestCase):
inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize()
for _ in range(5): f(inp, inp_cpu)
# construct minimal CALL UOps for supports_exec_item
gpu_call = UOp(Ops.SINK).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float))
cpu_call = UOp(Ops.SINK).call(UOp.new_buffer("CPU", 1, dtypes.float))
# construct minimal CALL UOps for supports_exec_item (graphs only see PROGRAMs after compile_linear)
gpu_call = UOp(Ops.PROGRAM).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float))
cpu_call = UOp(Ops.PROGRAM).call(UOp.new_buffer("CPU", 1, dtypes.float))
gpu_devs = [d0]
# local MMIO: GPU works alone and with CPU in batch (cpu_support=True)

View file

@ -62,7 +62,7 @@ def graph_split_rewrite(linear:UOp, max_batch_size:int=0) -> UOp:
def _call_outs_ins(call:UOp) -> tuple[set[int], set[int]]:
non_bind = [s for s in call.src[1:] if s.op is not Ops.BIND]
ast = call.src[0]
if ast.op in (Ops.SINK, Ops.PROGRAM):
if ast.op is Ops.PROGRAM:
prg = get_runner(non_bind[0].device if isinstance(non_bind[0].device, str) else non_bind[0].device[0], call.src[0])
return set(prg.p.outs), set(prg.p.ins)
if ast.op in (Ops.COPY, Ops.BUFFER_VIEW): return {0}, {1}
@ -107,7 +107,7 @@ class GraphRunner(Runner):
replace = [(p, b.arg) for p, b in enumerate(b for b in call.src[1:] if b.op is not Ops.BIND) if b.op is Ops.PARAM]
for dev_idx, (bufs, device_vars) in enumerate(unwrap_multi(call, resolve_params(call, input_uops))):
self.calls.append((dev_idx, call.src[0], [b.ensure_allocated() for b in bufs], device_vars))
self.progs.append(get_runner(bufs[0].device, call.src[0]) if call.src[0].op in (Ops.SINK, Ops.PROGRAM) else None)
self.progs.append(get_runner(bufs[0].device, call.src[0]) if call.src[0].op is Ops.PROGRAM else None)
self.uop_replace.append(replace)
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
@ -175,14 +175,14 @@ class GraphRunner(Runner):
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
return new_call.src[0].op is Ops.PROGRAM and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner):
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
# Devices must be the same type
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
return new_call.src[0].op in (Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
ReturnType = TypeVar('ReturnType')
@dataclass

View file

@ -4,6 +4,7 @@ from dataclasses import dataclass, replace, field
from tinygrad.helpers import colored, DEBUG, GlobalCounters, ansilen, NOOPT, all_int, Metadata, TRACEMETA, TracingKey
from tinygrad.helpers import BEAM, DEVECTORIZE, size_to_str, time_to_str, VALIDATE_WITH_CPU, cpu_profile, PROFILE, ProfilePointEvent, cpu_events
from tinygrad.helpers import prod, EMULATED_DTYPES, flatten
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, buffers, graph_rewrite
from tinygrad.device import Device, Buffer, MultiBuffer
from tinygrad.renderer import ProgramSpec, Estimates
@ -174,19 +175,18 @@ def exec_kernel(ctx:ExecContext, call, ast):
prg = get_runner(bufs[0].device, ast)
prg_bufs = [bufs[i].ensure_allocated() for i in prg.p.globals]
if VALIDATE_WITH_CPU and ast.op is Ops.SINK:
cpu_bufs = [Buffer("CPU", b.size, b.dtype).ensure_allocated().copyin(b.ensure_allocated().as_memoryview()) for b in bufs]
with track_stats(ctx, call, prg.device, prg.display_name, prg_bufs, var_vals,
outputs=tuple(prg.p.outs), inputs=tuple(prg.p.ins), first_run=prg.first_run) as timing:
timing[0] = prg(prg_bufs, var_vals, wait=DEBUG >= 2)
prg.first_run = False
if VALIDATE_WITH_CPU and ast.op is Ops.SINK:
import numpy as np
cpu_prg = get_runner("CPU", ast)
cpu_prg([cpu_bufs[i] for i in cpu_prg.p.globals], var_vals, wait=False)
for i in prg.p.outs: np.testing.assert_allclose(prg_bufs[i].numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3)
def exec_validate(ctx:ExecContext, call, ast):
import numpy as np
for bufs, device_vars in unwrap_multi(call, resolve_params(call, ctx.input_uops)):
cpu_bufs, dev_bufs = bufs[:len(bufs)//2], bufs[len(bufs)//2:]
cpu_prg = get_runner("CPU", ast.src[0])
cpu_prg([cpu_bufs[i].ensure_allocated() for i in cpu_prg.p.globals], {**ctx.var_vals, **device_vars}, wait=False)
for i in cpu_prg.p.outs: np.testing.assert_allclose(dev_bufs[i].ensure_allocated().numpy(), cpu_bufs[i].numpy(), rtol=1e-3, atol=1e-3)
def exec_encdec(ctx:ExecContext, call, ast):
bufs = [cast(Buffer, b.buffer).ensure_allocated() for b in resolve_params(call, ctx.input_uops)]
@ -202,6 +202,19 @@ def exec_graph(ctx:ExecContext, call, cf):
with track_stats(ctx, call, runner.device, runner.display_name, bufs, ctx.var_vals) as t:
t[0] = runner(bufs, ctx.var_vals, wait=DEBUG >= 2, input_uops=ctx.input_uops) # type: ignore[call-arg]
# flatten LINEAR-in-LINEAR: any nested LINEAR child gets inlined into its parent's src
pm_flatten_linear = PatternMatcher([
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="lin"),
lambda lin: lin.replace(src=tuple(flatten(c.src if c.op is Ops.LINEAR else (c,) for c in lin.src)))),
])
def _validate(call:UOp, sink:UOp) -> UOp:
params = tuple(p for p in call.src[1:] if p.op is not Ops.BIND)
shadows = tuple(UOp.new_buffer(("CPU",)*len(p.device) if isinstance(p.device, tuple) else "CPU", prod(p.max_shape), p.dtype.base) for p in params)
copies = tuple(p.copy_to_device(s.device).call(s, p) for s, p in zip(shadows, params))
return UOp(Ops.LINEAR, src=copies + (call, UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(sink,), arg="validate").call(*shadows, *params)))
pm_validate = PatternMatcher([(UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True), _validate)]) + pm_flatten_linear
# ctx is beam value
pm_beam = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.SINK, name="sink"),), name="call", allow_any_len=True),
@ -216,16 +229,18 @@ pm_compile = PatternMatcher([
pm_exec = PatternMatcher([
(UPat(Ops.CALL, src=(UPat(Ops.BUFFER_VIEW, name="ast"),), name="call", allow_any_len=True), exec_view),
(UPat(Ops.CALL, src=(UPat(Ops.COPY, name="ast"),), name="call", allow_any_len=True), exec_copy),
(UPat(Ops.CALL, src=(UPat((Ops.PROGRAM, Ops.SINK), name="ast"),), name="call", allow_any_len=True), exec_kernel),
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="ast"),), name="call", allow_any_len=True), exec_kernel),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="ast"),), name="call", allow_any_len=True), exec_encdec),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"),), name="call", allow_any_len=True), exec_graph),
(UPat(Ops.CALL, src=(UPat(Ops.CUSTOM_FUNCTION, arg="validate", name="ast"),), name="call", allow_any_len=True), exec_validate),
])
def compile_linear(linear:UOp, beam=0) -> UOp:
def compile_linear(linear:UOp, beam=0, validate=False) -> UOp:
if validate: linear = graph_rewrite(linear, pm_validate, name="validate", walk=True)
if (beam_val:=(beam or BEAM.value)) >= 1: linear = graph_rewrite(linear, pm_beam, ctx=beam_val, walk=True)
return graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True) if not VALIDATE_WITH_CPU else linear
return graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True)
def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), do_update_stats=True, jit=False):
if not jit: linear = compile_linear(linear)
if not jit: linear = compile_linear(linear, validate=VALIDATE_WITH_CPU)
ctx = ExecContext(var_vals or {}, input_uops, do_update_stats, jit)
for call in linear.src: pm_exec.rewrite(call, ctx)

View file

@ -15,7 +15,7 @@ class CUDAGraph(MultiGraphRunner):
self.graph = init_c_var(cuda.CUgraph, lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
for (dev_idx, ast, bufs, device_vars), prg in zip(self.calls, self.progs):
if ast.op in (Ops.SINK, Ops.PROGRAM):
if ast.op is Ops.PROGRAM:
assert prg is not None
global_size, local_size = prg.p.launch_dims({v: 0 for v in self.vars})

View file

@ -333,4 +333,4 @@ class HCQGraph(MultiGraphRunner):
# MOCKGPU is not supported, since it can't execute commands in parallel
is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer
return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getattr(all_devs[0], 'iface', None).__class__.__name__.startswith("MOCK"))
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM)
return new_call.src[0].op is Ops.PROGRAM

View file

@ -2,7 +2,7 @@ import time, inspect
from collections import deque
from tinygrad.uop.ops import UOp, Ops, UOpMetaClass, track_rewrites, graph_rewrite, gate_kernel_sink, KernelInfo
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, flatten, partition
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE, BASEDIR, partition
# **** schedule linearizer
@ -67,7 +67,7 @@ def create_schedule(sched_sink:UOp) -> UOp:
return UOp(Ops.LINEAR, src=tuple(linearized))
from tinygrad.schedule.memory import memory_plan_rewrite
from tinygrad.engine.realize import capturing
from tinygrad.engine.realize import capturing, pm_flatten_linear
from tinygrad.schedule.rangeify import get_kernel_graph
from tinygrad.helpers import CAPTURING
from tinygrad.uop.ops import PatternMatcher, UPat
@ -86,10 +86,7 @@ pm_resolve_linear_call = PatternMatcher([
# call LINEAR is resolved here
(UPat(Ops.CALL, src=(UPat(Ops.LINEAR),), name="linear_call", allow_any_len=True), lambda linear_call:
graph_rewrite(linear_call.src[0], pm_post_sched_cache, ctx=({}, linear_call.src[1:]), walk=True, name="params to buffers")),
# LINEAR on LINEAR
(UPat(Ops.LINEAR, custom_early_reject={Ops.LINEAR}, name="x"),
lambda x: x.replace(src=tuple(flatten(x.src if x.op is Ops.LINEAR else (x,) for x in x.src)))),
])
])+pm_flatten_linear
schedule_cache: dict[bytes, UOp] = {}
# ctx is just for DEBUG on inner