Revert "remove cpu graph, it's different from the others (#10743)" (#10745)

This reverts commit 3d64a98432.
This commit is contained in:
George Hotz 2025-06-09 22:40:48 -07:00 committed by GitHub
commit 413e223d6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 81 additions and 6 deletions

View file

@ -0,0 +1,67 @@
from typing import cast, TypeVar, Generic, get_args as get_typing_args
import itertools
from tinygrad.helpers import dedup, flatten, DEBUG, to_function_name
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.uop.ops import Variable
from tinygrad.dtype import DType, dtypes
from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.renderer.llvmir import LLVMRenderer, ldt
T = TypeVar('T')
class BatchedGraph(Generic[T], GraphRunner):
  """GraphRunner that compiles an entire JIT cache into one program named "batched".

  The generic parameter names the renderer class a subclass supports
  (e.g. ``BatchedGraph[ClangRenderer]``). Subclasses override `_prepare_code`
  to produce the source that is compiled and loaded here.
  """
  def __init__(self, device, jit_cache: list[ExecItem], input_rawbuffers: list[Buffer], var_vals: dict[Variable, int]):
    """Validate the device renderer, collect buffers, then render, compile and load the batched program.

    Raises:
      GraphException: if `device.renderer` is not an instance of the renderer class given as the generic parameter.
    """
    # the supported renderer class is the type argument of the first generic base of the concrete subclass
    generic_base = getattr(self, "__orig_bases__")[0]
    renderer_class = get_typing_args(generic_base)[0]
    if not (issubclass(type(device.renderer), renderer_class) or isinstance(device.renderer, renderer_class)):
      raise GraphException
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # base buffers of every non-None cache buffer that is not one of the graph inputs, without duplicates
    self.base_bufs = dedup(buf.base for item in jit_cache for buf in item.bufs if buf is not None and buf not in input_rawbuffers)
    self.base_rawbufs = [base._buf for base in self.base_bufs]

    # parameter list of the batched entry point: inputs first, then base buffers as char pointers, then variables sorted
    input_params = [(f"arg{idx}", inp.dtype.ptr()) for idx, inp in enumerate(input_rawbuffers)]
    base_params = [(f"cbuf{idx}", dtypes.char.ptr()) for idx in range(len(self.base_bufs))]
    var_params = sorted([(f"{var.expr}", dtypes.int) for var in var_vals])
    targs = input_params + base_params + var_params

    code = self._prepare_code(device.renderer, jit_cache, input_rawbuffers, targs)
    if DEBUG >= 4: print(code)
    self.clprg = device.runtime("batched", device.compiler.compile_cached(code))

  def _prepare_code(self, renderer:T, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
    """Return the source for the batched program; this base implementation returns an empty string."""
    return ""

  def __call__(self, rawbufs: list[Buffer], var_vals: dict[Variable, int], wait=False):
    """Invoke the loaded program with input buffers, the base buffers, and variable values ordered by expression."""
    input_ptrs = [buf._buf for buf in rawbufs]
    # ordered by expr so the values line up with the sorted variable parameters built in __init__
    ordered_vars = sorted(var_vals.items(), key=lambda kv: kv[0].expr)
    var_values = [kv[1] for kv in ordered_vars]
    return self.clprg(*input_ptrs, *self.base_rawbufs, *var_values, wait=wait)
class CPUGraph(BatchedGraph[ClangRenderer]):
  """BatchedGraph for ClangRenderer: renders every cached kernel plus a `batched` function that calls them in order."""
  def _prepare_code(self, renderer:ClangRenderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
    """Build the source text: defines, the kernel functions, the `batched` function, then the rendered entry.

    Args:
      renderer: renderer used for dtypes, kernel bodies, defines and the entry.
      jit_cache: cache entries whose programs are called, in order, from `batched`.
      input_rawbuffers: graph inputs; they are referenced as ``arg<index>``.
      targs: (name, dtype) pairs forming the parameter list of `batched`.
    """
    def render_arg(buf):
      # non-input buffers are addressed as an offset into the char pointer of their base buffer
      if buf not in input_rawbuffers:
        return f"({renderer.render_dtype(buf.dtype)}*)(cbuf{self.base_bufs.index(buf.base)} + {buf.offset})"
      return f"arg{input_rawbuffers.index(buf)}"

    signature = ','.join([f"{renderer.render_dtype(dtype)} {name}" for name, dtype in targs])
    batched = ["void batched(" + signature + ") {"]
    for ji in jit_cache:
      prog = cast(CompiledRunner, ji.prg).p
      call_args = [render_arg(buf) for buf in ji.bufs] + [var.expr for var in prog.vars]
      batched.append(f" {to_function_name(prog.name)}({','.join(call_args)});")
    batched.append("}")

    # render every kernel; identical bodies and defines are emitted once thanks to dedup
    prep = [renderer._render(cast(CompiledRunner, ji.prg).p.uops or []) for ji in jit_cache]
    funcs = dedup(renderer._render_body(pre[0], *pre[1:], cast(CompiledRunner, ji.prg).p.uops,
                                        ["static", "__attribute__((always_inline))"]) for pre, ji in zip(prep, jit_cache))
    defines = dedup(itertools.chain.from_iterable(renderer._render_defines(cast(CompiledRunner, ji.prg).p.uops) for ji in jit_cache))
    entry = renderer._render_entry("batched", [(name, (dtype, False)) for name, dtype in targs])

    defines_src = '\n'.join(defines)
    funcs_src = '\n'.join([''.join(f) for f in funcs])
    batched_src = '\n'.join(batched)
    return defines_src + '\n' + funcs_src + batched_src + '\n' + entry
class LLVMGraph(BatchedGraph[LLVMRenderer]):
  """BatchedGraph for LLVMRenderer: renders every cached kernel plus a `batched` function made of calls to them."""
  def _prepare_code(self, renderer, jit_cache:list[ExecItem], input_rawbuffers:list[Buffer], targs:list[tuple[str, DType]]) -> str:
    """Build the IR text: deduplicated kernels, the `batched` function, then the renderer footer.

    Args:
      renderer: renderer used for kernels, the batched function and the footer.
      jit_cache: cache entries whose programs are called, in order, from `batched`.
      input_rawbuffers: graph inputs; they are referenced as ``%arg<index>``.
      targs: (name, dtype) pairs forming the parameter list of `batched`.
    """
    out = []
    for i,ji in enumerate(jit_cache):
      call_args = []
      for j,buf in enumerate(cast(list[Buffer], ji.bufs)):
        if buf in input_rawbuffers:
          arg = f"%arg{input_rawbuffers.index(buf)}"
        else:
          # non-input buffers get a per-call pointer computed from their base buffer and byte offset
          arg = f"%b{i}_{j}"
          out.append(f" {arg} = getelementptr inbounds i8,ptr %cbuf{self.base_bufs.index(buf.base)},i64 {buf.offset}")
        call_args.append(f"{ldt(buf.dtype.ptr())} {arg}")
      call_args += [f"{ldt(var.dtype)} %{var.expr}" for var in cast(CompiledRunner, ji.prg).p.vars]
      out.append(f" call void @{to_function_name(cast(CompiledRunner, ji.prg).p.name)}({','.join(call_args)})")

    kernels = dedup(tuple(renderer._render_kernel(cast(CompiledRunner, item.prg).p.uops, ["internal"]) for item in jit_cache))
    kernels += [((), renderer._render_fn("batched", [(f"%{name}", dtype) for name, dtype in targs], out))]
    assert flatten(k[0] for k in kernels) == [] # global definitions are not used in CPU mode right now

    sections = [k[1] for k in kernels]
    # `ji` is the variable left over from the loop above, so the footer is rendered from the last cache entry's uops
    sections.append(renderer._render_footer(cast(CompiledRunner, ji.prg).p.uops))
    return "\n".join(sections)

View file

@ -1,4 +1,4 @@
import platform, subprocess, sys
import functools, platform, subprocess, sys
from tinygrad.helpers import capstone_flatdump, getenv
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
from tinygrad.runtime.support.elf import jit_loader
@ -18,5 +18,9 @@ class ClangJITCompiler(Compiler):
def disassemble(self, lib:bytes): return capstone_flatdump(lib)
class CPUDevice(Compiled):
def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram)
class ClangDevice(Compiled):
  """Compiled device built from MallocAllocator, ClangRenderer, ClangJITCompiler and CPUProgram, with CPUGraph as its graph."""
  def __init__(self, device:str):
    # imported inside the constructor, as in the original; presumably to avoid an import cycle -- TODO confirm
    from tinygrad.runtime.graph.cpu import CPUGraph
    # the graph is bound to this device instance so CPUGraph receives it as its first argument
    graph = functools.partial(CPUGraph, self)
    super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram, graph)

# CPUDevice is an alias for ClangDevice
CPUDevice = ClangDevice

View file

@ -1,4 +1,4 @@
import ctypes, platform
import functools, ctypes, platform
from tinygrad.device import Compiled, Compiler, MallocAllocator, CPUProgram
from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG
from tinygrad.renderer.llvmir import LLVMRenderer
@ -70,4 +70,6 @@ class HostLLVMCompiler(LLVMCompiler):
super().__init__(cpu.decode(), feats.decode())
class LLVMDevice(Compiled):
def __init__(self, device:str): super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram)
def __init__(self, device:str):
  """Set up the device with MallocAllocator, LLVMRenderer, HostLLVMCompiler and CPUProgram, using LLVMGraph as its graph."""
  # imported inside the constructor, as in the original; presumably to avoid an import cycle -- TODO confirm
  from tinygrad.runtime.graph.cpu import LLVMGraph
  # the graph is bound to this device instance so LLVMGraph receives it as its first argument
  graph = functools.partial(LLVMGraph, self)
  super().__init__(device, MallocAllocator, LLVMRenderer(), HostLLVMCompiler(), CPUProgram, graph)

View file

@ -16,6 +16,7 @@ from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
from tinygrad.engine.realize import CompiledRunner, BufferXfer
from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
from tinygrad.runtime.graph.cpu import CPUGraph
# ***** API *****
@ -170,7 +171,8 @@ class RemoteHandler:
case SessionFree(): del self.sessions[unwrap(c.session)]
case GetProperties():
cls, args = dev.renderer.__reduce__()
graph_cls = graph_class(Device[self.base_device])
# CPUGraph re-renders kernels from the uops specified in CompiledRunner; this is not supported, so it is reported as no graph class
graph_cls = gt if (gt:=graph_class(Device[self.base_device])) is not CPUGraph else None
rp = RemoteProperties(
real_device=dev.device, renderer=(cls.__module__, cls.__name__, args),
graph_supported=graph_cls is not None, graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner),