mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
remove CONST(UNIQUE) (#16383)
This commit is contained in:
parent
d861c50dce
commit
0b88827482
8 changed files with 5 additions and 22 deletions
|
|
@ -3,7 +3,6 @@ from tinygrad import Tensor, dtypes
|
|||
from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher, graph_rewrite
|
||||
|
||||
_strip_unique_pm = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(UOp.unique(0), d))),
|
||||
])
|
||||
def _strip_unique(u: UOp) -> UOp: return graph_rewrite(u, _strip_unique_pm)
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class TestCallSchedule(unittest.TestCase):
|
|||
# find the FUNCTION nodes
|
||||
c0 = next(u for u in r0.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
c1 = next(u for u in r1.uop.toposort() if u.op is Ops.FUNCTION)
|
||||
# the function bodies (src[0]) should have identical keys — unique consts must not leak through
|
||||
# the function bodies (src[0]) should have identical keys — local buffer identity must not leak through
|
||||
self.assertEqual(c0.src[0].key, c1.src[0].key)
|
||||
|
||||
def test_precompile_symbolic_2d(self):
|
||||
|
|
|
|||
|
|
@ -181,8 +181,6 @@ def replace_input_buffer(ctx:AllocCtx, b:UOp):
|
|||
pm_finalize_call = PatternMatcher([
|
||||
(UPat(Ops.AFTER, name="x"), finalize_after),
|
||||
(UPat(Ops.COPY, name="x"), lambda ctx,x: ctx.assigns.append(x) if isinstance(x.device, str) and x.device.startswith(("DISK", "TINYFS")) else None),
|
||||
# remove unique from const. TODO: this is copied in function.py
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE, name="d")), name="b"), lambda b,d: b.replace(src=(d,))),
|
||||
])
|
||||
|
||||
pm_replace_buf = PatternMatcher([
|
||||
|
|
|
|||
|
|
@ -10,19 +10,13 @@ def add_to_ctx(ctx, x:UOp):
|
|||
ctx[0].append(x)
|
||||
return ret
|
||||
|
||||
pm_transform_unique_const = PatternMatcher([
|
||||
# transform unique consts to LUNIQUE
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
|
||||
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1]))),
|
||||
])
|
||||
|
||||
pm_ctx = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="x"),
|
||||
lambda ctx,x: x.replace(src=(UOp(Ops.LUNIQUE, arg=next(ctx[1])), x.src[1])) if x.src[0].arg > ctx[2] else add_to_ctx(ctx,x)),
|
||||
(UPat(Ops.BIND, name="x"), add_to_ctx),
|
||||
(UPat((Ops.AFTER, Ops.CONTIGUOUS), name="x"),
|
||||
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) and x.op_in_backward_slice_with_self(Ops.BUFFER) else None),
|
||||
])+pm_transform_unique_const
|
||||
])
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
class _function(Generic[ReturnType]):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import cast
|
||||
import math, dataclasses, itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
|
||||
import math, dataclasses
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import sum_acc_dtype
|
||||
|
||||
|
|
@ -41,9 +41,6 @@ def call_gradient(ctx:UOp, k:UOp, needed:set[int]) -> tuple[UOp|None, ...]:
|
|||
grad_bodies = [(i, grads[p]) for i in needed if (p:=params.get(i)) is not None and p in grads]
|
||||
bwd_body = UOp.maketuple(*(gb for _, gb in grad_bodies)).substitute(fwd_subs, walk=True)
|
||||
bwd_body, compact_args = _compact_params(bwd_body, (*args, *grad_args, *fwd_outs))
|
||||
# TODO: is this okay here?
|
||||
from tinygrad.function import pm_transform_unique_const
|
||||
bwd_body = graph_rewrite(bwd_body, pm_transform_unique_const, ctx=(None, itertools.count(0)))
|
||||
bwd_call = bwd_body.call(*compact_args, name=(k.arg.name or "")+"_backward", precompile=k.arg.precompile_backward)
|
||||
gb_map = {i: idx for idx, (i, _) in enumerate(grad_bodies)}
|
||||
return (None,) + tuple(bwd_call.gettuple(gb_map[i]) if i in gb_map else None for i in range(len(args)))
|
||||
|
|
|
|||
|
|
@ -79,8 +79,6 @@ def render_marg(ctx,x:UOp):
|
|||
sugar = {Ops.SINK, Ops.END, Ops.STORE, Ops.LOAD, Ops.UNIQUE, Ops.SQRT, Ops.INDEX, Ops.REDUCE, Ops.AFTER, Ops.THREEFRY,
|
||||
Ops.WHERE, Ops.RECIPROCAL, Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.CONTIGUOUS, Ops.BARRIER, Ops.DETACH}
|
||||
pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
|
||||
lambda x,u,d: f"UOp.unique_const({x.arg}, dtype={x.dtype}, device={repr(d.arg)}, unique={u.arg})"),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
|
||||
|
|
@ -106,7 +104,6 @@ pm_pyrender_extra = PatternMatcher([
|
|||
# explicit trunc ops: `//` and `%` parse as FLOORDIV/FLOORMOD, so render CDIV/CMOD via .alu()
|
||||
(UPat(Ops.CDIV, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.CDIV, {ctx[x.src[1]]})"),
|
||||
(UPat(Ops.CMOD, name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.alu(Ops.CMOD, {ctx[x.src[1]]})"),
|
||||
# NOTE: only match CONSTs without UNIQUE (len(src)==1), unique_const needs explicit rendering
|
||||
(UPat(set(syms.keys())-{Ops.SUB, Ops.CMPNE, Ops.CDIV, Ops.CMOD}, src=(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),), name="y"), UPat(name="z")),
|
||||
name="x"), lambda ctx,x,y,z: strip_binary_parens(x, str(y.arg), ctx[z], lambda a,b: f"({a}{syms[x.op]}{b})") if y.device==z.device else None),
|
||||
# NOTE: sub doesn't work cause it's written as add/mul
|
||||
|
|
|
|||
|
|
@ -122,9 +122,8 @@ spec_tensor = PatternMatcher([
|
|||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||
|
||||
# CONST with a UNIQUE or DEVICE
|
||||
# CONST with a DEVICE
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
(UPat(Ops.CONST, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE))), lambda: True),
|
||||
|
||||
# BUFFER
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.UNIQUE, Ops.LUNIQUE)), UPat(Ops.DEVICE)), name="buf"),
|
||||
|
|
|
|||
|
|
@ -117,7 +117,6 @@ def uop_to_json(data:VizData, x:UOp) -> dict[int, dict]:
|
|||
for u in (toposort:=x.toposort()):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.CONST and len(u.src) and u.src[0].op in {Ops.UNIQUE, Ops.LUNIQUE}: excluded.remove(u)
|
||||
if u.op is Ops.STACK and len(u.src) == 0: excluded.add(u)
|
||||
# exclude RESHAPE/EXPAND that only serve to broadcast a CONST
|
||||
if u.op in {Ops.RESHAPE, Ops.EXPAND} and len(u.src) >= 1 and u.src[0] in excluded and u is not x: excluded.add(u)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue