Merge branch 'master' into codegen2

This commit is contained in:
George Hotz 2026-06-19 18:29:13 -07:00 committed by GitHub
commit 0decf136fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 50 additions and 27 deletions

View file

@ -1,8 +1,9 @@
import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor, GlobalCounters
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad import Tensor, GlobalCounters, Device
from tinygrad.dtype import dtypes, Invalid
from tinygrad.uop.ops import UOp, Ops, KernelInfo, ProgramInfo
class TestFunction(unittest.TestCase):
def test_simple(self):
@ -549,6 +550,36 @@ class TestFunctionTuple(unittest.TestCase):
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
def test_custom_kernel_program_invalids_not_captured(self):
# llama FP8 kernels are PROGRAM with bare-buffer sinks (no analyzable stores), so the invalids scratch
# still must not be captured as an input -- else it is read before the kernel writes it
src = "void k(float* restrict data0, float* restrict data1) { for (int i=0;i<4;i++) data0[i]=data1[i]*2.0f; }"
lib = Device["CPU"].compiler.compile(src)
def prog(C:UOp, A:UOp) -> UOp:
sink = UOp.sink(C.base, A.base, arg=KernelInfo(name="k"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="CPU"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)),
arg=ProgramInfo(name="k", global_size=(1, 1, 1), local_size=(1, 1, 1), globals=(0, 1)))
@function(precompile=True)
def f(a:Tensor):
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
return Tensor.custom_kernel(c, a, fxn=prog)[0]
a = Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()
np.testing.assert_allclose(f(a).numpy(), [2., 4., 6., 8.])
def test_invalid_store_into_realized_buffer_is_captured(self):
# only fresh invalids() scratch is skipped; a realized buffer is a real input even if an Invalid store
# writes into part of it (its other elements must be preserved), so it is still captured
state = Tensor([10., 20., 30., 40.], device="CPU").contiguous().realize()
@function(precompile=True, allow_implicit=True)
def f(a:Tensor):
after = state.uop.after(state.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float32, Invalid, shape=(2,))))
return Tensor(after).contiguous() + a
out = f(Tensor([1., 1., 1., 1.], device="CPU").contiguous().realize())
np.testing.assert_allclose(out.numpy(), [11., 21., 31., 41.])
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
devs = ("CPU:0", "CPU:1")
def my_kernel(C:UOp, A:UOp) -> UOp:

View file

@ -1,26 +1,29 @@
import functools, itertools, time
import functools, time
from typing import Generic, TypeVar, Callable, cast, overload
from tinygrad.helpers import Context, dedup, getenv, DEBUG
from tinygrad.dtype import Invalid
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
def add_to_ctx(ctx, x:UOp):
if x.buf_uop in ctx[1]: return None
ret = x.param_like(len(ctx[0]))
ctx[0].append(x)
return ret
pm_transform_unique_const = PatternMatcher([
# transform unique consts to LUNIQUE
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
])
pm_ctx = PatternMatcher([
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
])+pm_transform_unique_const
])
def invalid_outputs(uret:UOp) -> set[UOp]:
# invalids() returns fresh write-only scratch: a clone storing CONST(Invalid)
# don't capture it as an input; only skip fresh buffers, not realized ones
return {u.src[0].buf_uop for u in uret.backward_slice_with_self
if u.op is Ops.STORE and u.src[1].base.op is Ops.CONST and u.src[1].base.arg is Invalid
and not u.src[0].buf_uop.is_realized}
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
@ -63,7 +66,7 @@ class _function(Generic[ReturnType]):
# the BUFFERs that are left are the implicit inputs
num_explicit = len(call_uops)
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
uret = graph_rewrite(uret, pm_ctx, (call_uops, invalid_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
if not self.allow_implicit:
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]

View file

@ -1,6 +1,6 @@
from typing import cast
import math, dataclasses, itertools
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
import math, dataclasses
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
from tinygrad.dtype import sum_acc_dtype
@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
# TODO: is this okay here?
from tinygrad.function import pm_transform_unique_const
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))

View file

@ -561,11 +561,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
@staticmethod
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
return UOp.const(dt, Invalid, shape=shape).clone(device=device)
@staticmethod
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)

View file

@ -1,5 +1,5 @@
from typing import cast
from tinygrad.dtype import dtypes, Invalid
from tinygrad.dtype import dtypes
from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
from tinygrad.helpers import strip_parens
@ -77,8 +77,6 @@ def render_marg(ctx,x:UOp):
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
pm_pyrender_extra = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),

View file

@ -126,9 +126,6 @@ spec_tensor = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
# CONST with a UNIQUE and DEVICE
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
# BUFFER
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),

View file

@ -121,7 +121,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)