mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
invalid clone try 3+ [PR] (#16679)
This commit is contained in:
parent
b2199c54a3
commit
8b07cca9f7
7 changed files with 50 additions and 27 deletions
|
|
@ -1,8 +1,9 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
from tinygrad.function import function
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad import Tensor, GlobalCounters, Device
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, ProgramInfo
|
||||
|
||||
class TestFunction(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
|
|
@ -549,6 +550,36 @@ class TestFunctionTuple(unittest.TestCase):
|
|||
f(Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()).realize()
|
||||
np.testing.assert_allclose(state.numpy(), [2., 4., 6., 8.])
|
||||
|
||||
def test_custom_kernel_program_invalids_not_captured(self):
|
||||
# llama FP8 kernels are PROGRAM with bare-buffer sinks (no analyzable stores), so the invalids scratch
|
||||
# still must not be captured as an input -- else it is read before the kernel writes it
|
||||
src = "void k(float* restrict data0, float* restrict data1) { for (int i=0;i<4;i++) data0[i]=data1[i]*2.0f; }"
|
||||
lib = Device["CPU"].compiler.compile(src)
|
||||
def prog(C:UOp, A:UOp) -> UOp:
|
||||
sink = UOp.sink(C.base, A.base, arg=KernelInfo(name="k"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="CPU"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)),
|
||||
arg=ProgramInfo(name="k", global_size=(1, 1, 1), local_size=(1, 1, 1), globals=(0, 1)))
|
||||
|
||||
@function(precompile=True)
|
||||
def f(a:Tensor):
|
||||
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
|
||||
return Tensor.custom_kernel(c, a, fxn=prog)[0]
|
||||
|
||||
a = Tensor([1., 2., 3., 4.], device="CPU").contiguous().realize()
|
||||
np.testing.assert_allclose(f(a).numpy(), [2., 4., 6., 8.])
|
||||
|
||||
def test_invalid_store_into_realized_buffer_is_captured(self):
|
||||
# only fresh invalids() scratch is skipped; a realized buffer is a real input even if an Invalid store
|
||||
# writes into part of it (its other elements must be preserved), so it is still captured
|
||||
state = Tensor([10., 20., 30., 40.], device="CPU").contiguous().realize()
|
||||
@function(precompile=True, allow_implicit=True)
|
||||
def f(a:Tensor):
|
||||
after = state.uop.after(state.uop.shrink(((0, 2),)).store(UOp.const(dtypes.float32, Invalid, shape=(2,))))
|
||||
return Tensor(after).contiguous() + a
|
||||
out = f(Tensor([1., 1., 1., 1.], device="CPU").contiguous().realize())
|
||||
np.testing.assert_allclose(out.numpy(), [11., 21., 31., 41.])
|
||||
|
||||
def test_custom_kernel_precompile_further_compute(self, multi=False, kernel_count:int=2):
|
||||
devs = ("CPU:0", "CPU:1")
|
||||
def my_kernel(C:UOp, A:UOp) -> UOp:
|
||||
|
|
|
|||
|
|
@ -1,26 +1,29 @@
|
|||
import functools, itertools, time
|
||||
import functools, time
|
||||
from typing import Generic, TypeVar, Callable, cast, overload
|
||||
from tinygrad.helpers import Context, dedup, getenv, DEBUG
|
||||
from tinygrad.dtype import Invalid
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
|
||||
def add_to_ctx(ctx, x:UOp):
|
||||
if x.buf_uop in ctx[1]: return None
|
||||
ret = x.param_like(len(ctx[0]))
|
||||
ctx[0].append(x)
|
||||
return ret
|
||||
|
||||
pm_transform_unique_const = PatternMatcher([
|
||||
# transform unique consts to LUNIQUE
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
|
||||
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
|
||||
])
|
||||
|
||||
pm_ctx = PatternMatcher([
|
||||
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
|
||||
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
|
||||
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
|
||||
])+pm_transform_unique_const
|
||||
])
|
||||
|
||||
def invalid_outputs(uret:UOp) -> set[UOp]:
|
||||
# invalids() returns fresh write-only scratch: a clone storing CONST(Invalid)
|
||||
# don't capture it as an input; only skip fresh buffers, not realized ones
|
||||
return {u.src[0].buf_uop for u in uret.backward_slice_with_self
|
||||
if u.op is Ops.STORE and u.src[1].base.op is Ops.CONST and u.src[1].base.arg is Invalid
|
||||
and not u.src[0].buf_uop.is_realized}
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class _function(Generic[ReturnType]):
|
||||
|
|
@ -63,7 +66,7 @@ class _function(Generic[ReturnType]):
|
|||
|
||||
# the BUFFERs that are left are the implicit inputs
|
||||
num_explicit = len(call_uops)
|
||||
uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
|
||||
uret = graph_rewrite(uret, pm_ctx, (call_uops, invalid_outputs(uret)), bottom_up=True, name="get_implicit_inputs")
|
||||
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
|
||||
if not self.allow_implicit:
|
||||
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
import math, dataclasses, itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
|
||||
import math, dataclasses
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import sum_acc_dtype
|
||||
|
||||
|
|
@ -42,9 +42,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
|||
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
|
||||
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
|
||||
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
|
||||
# TODO: is this okay here?
|
||||
from tinygrad.function import pm_transform_unique_const
|
||||
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
|
||||
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
|
||||
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
|
||||
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
|
||||
|
|
|
|||
|
|
@ -560,11 +560,9 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
|||
ret = UOp(Ops.CONST, dtype, arg=dtype.const(b), src=())
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and shape != () and ret.shape != shape else ret
|
||||
@staticmethod
|
||||
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None, unique=True) -> UOp:
|
||||
def invalids(shape:tuple[sint, ...]|None=None, dtype:DTypeLike|None=None, device:str|tuple[str, ...]|None=None) -> UOp:
|
||||
dt = to_dtype(dtype) if dtype is not None else dtypes.from_py(Invalid)
|
||||
ret = UOp(Ops.CONST, dt, arg=dt.const(Invalid),
|
||||
src=(UOp.unique(None if unique is True else unique), UOp(Ops.DEVICE, arg=canonicalize_device(device))))
|
||||
return ret.reshape((1,)*len(shape)).expand(shape) if shape is not None and ret.shape != shape else ret
|
||||
return UOp.const(dt, Invalid, shape=shape).clone(device=device)
|
||||
@staticmethod
|
||||
def range(end:sint, axis_id, axis_type=AxisType.LOOP, *arg, dtype=dtypes.weakint, src=(), **kwargs):
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),)+src, arg=(axis_id, axis_type)+arg, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, Invalid
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop import Ops, GroupOp
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, multirange_str, range_str, consumer_map_from_toposort
|
||||
from tinygrad.helpers import strip_parens
|
||||
|
|
@ -77,8 +77,6 @@ def render_marg(ctx,x:UOp):
|
|||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
|
||||
pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), arg=Invalid, name="x"),
|
||||
lambda x,u,d: f"UOp.invalids(dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.CONST),), name="x"), lambda x: f"UOp.special({x.src[0].arg}, {repr(x.arg)}, dtype={x.dtype})"),
|
||||
|
|
|
|||
|
|
@ -126,9 +126,6 @@ spec_tensor = PatternMatcher([
|
|||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||
|
||||
# CONST with a UNIQUE and DEVICE
|
||||
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="c"), lambda c: c.arg is Invalid),
|
||||
|
||||
# BUFFER
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, DType)),
|
||||
|
|
|
|||
|
|
@ -121,7 +121,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
for u in (toposort:=x.toposort()):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue