PADTO pads Invalids (#16562)

This commit is contained in:
chenyu 2026-06-11 16:54:26 -04:00 committed by GitHub
commit 5f1e2d3900
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 22 additions and 29 deletions

View file

@ -224,7 +224,6 @@ class TestKernelOpts(unittest.TestCase):
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
# can pad sum reduce axis if there's no unsafe ops prior to sum
for axis in (0, 1):
helper_linearizer_opt(a.sum(), [[Opt(OptOps.PADTO, axis, 32)],])
helper_linearizer_opt(a.sum(0), [[Opt(OptOps.PADTO, axis, 32)],])
@ -240,22 +239,16 @@ class TestKernelOpts(unittest.TestCase):
helper_linearizer_opt(a.sum().exp(), [[Opt(OptOps.PADTO, 0, 32)],])
helper_linearizer_opt(a.sum(0).exp(), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_sum_not_ok(self):
def test_padto_sum(self):
N = 18
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))).exp()
# exp is not safe to pad
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
helper_linearizer_opt(a.exp().sum(), [[Opt(OptOps.PADTO, 0, 32)],])
helper_linearizer_opt(a.exp().sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
b = a < 1
# lt is not safe to pad
with self.assertRaises(KernelOptError):
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
helper_linearizer_opt(b.sum(), [[Opt(OptOps.PADTO, 0, 32)],])
helper_linearizer_opt(b.sum(0), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_max(self):
N = 18
@ -271,11 +264,8 @@ class TestKernelOpts(unittest.TestCase):
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),],
])
# cannot pad max kernel on reduce
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
with self.assertRaises(KernelOptError):
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
helper_linearizer_opt(a.max(), [[Opt(OptOps.PADTO, 0, 32)],])
helper_linearizer_opt(a.max(0), [[Opt(OptOps.PADTO, 1, 32)],])
def test_padto_where(self):
Tensor.manual_seed(0)

View file

@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
# import all pattern matchers here
from tinygrad.codegen.gpudims import pm_add_gpudims
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing, symbolic, pm_move_where_on_load, pm_clean_up_group_sink, pm_remove_invalid
from tinygrad.uop.decompositions import get_late_rewrite_patterns, get_transcendental_patterns, pm_dtype_decomps
from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize_buf_and_index, devectorize_alu, pm_reduce, \
@ -89,8 +89,8 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
# **** optimizations are done, now we lower to actual code ****
# add loads
sink = graph_rewrite(sink, pm_add_loads, name="** add loads (code)")
# add loads and remove invalids
sink = graph_rewrite(sink, pm_add_loads+pm_remove_invalid, name="** add loads (code)")
# create image buffers
if IMAGE and ren.target.device in {"QCOM", "CL", "PYTHON", "NULL"}:

View file

@ -341,6 +341,9 @@ def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
return sink.substitute(subs) if subs else None
pm_reduce = PatternMatcher([
# invalid -> identity element
(UPat(Ops.REDUCE, src=(invalid_gate,), allow_any_len=True, name="red"), lambda red,cond,x,i:
red.replace(src=(cond.where(x, identity_element(red.arg[0], x.dtype.scalar())),)+red.src[1:])),
# REDUCE -> DEFINE_ACC+ASSIGN, then merge ENDs with same range
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
(UPat(Ops.SINK, name="sink"), merge_reduce_ends),

View file

@ -2,10 +2,10 @@ from __future__ import annotations
import math, itertools
from collections import defaultdict
from typing import cast, Final
from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, GroupOp, remove_all_tags
from tinygrad.uop.ops import Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, remove_all_tags
from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes
from tinygrad.dtype import dtypes, Invalid
from tinygrad.helpers import colored, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.helpers import ALLOW_TF32, count, Context
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
@ -188,17 +188,16 @@ class Scheduler:
check(rng.src[0].op is Ops.CONST, "only pad const axes")
check(rng.arg[-1] not in {AxisType.UPCAST, AxisType.UNROLL}, "cannot pad upcasted") # TODO: why is this wrong?
check(rng.arg[-1] is not AxisType.THREAD, "cannot pad thread")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and rng.arg[-1] in (AxisType.GROUP_REDUCE, AxisType.REDUCE):
check(r.arg[0] is Ops.ADD and not r.op_in_backward_slice_with_self(*GroupOp.UnsafePad), f"cannot pad {r}")
new_sz = round_up(int(rng.vmax+1), cast(int, opt.arg))
check(rng.vmax+1 > new_sz//4, "pad adds more than quadruple the work")
replaced_rng = UOp.range(new_sz, *rng.arg)
replaces = {rng:replaced_rng}
valid = replaced_rng < rng.vmax+1
store_targets = {s.src[0] for s in self.ast.backward_slice_with_self if s.op is Ops.STORE}
for b in self.bufs:
if rng in (i:=b.src[1].get_idx()).backward_slice_with_self:
replaces[b] = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
nb = b.replace(src=(b.src[0],(valid&b.src[1].get_valid()).where(i, UOp.invalid())))
replaces[b] = nb if b in store_targets else valid.where(nb, UOp.const(b.dtype, Invalid))
self.ast = self.ast.substitute(replaces, f"padto {rng.arg[:-1]} {opt.arg}")
elif opt.op is OptOps.SWAP:
try:

View file

@ -142,7 +142,4 @@ class GroupOp:
# These can change the dtype to bool
Comparison = {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIPROCAL, Ops.LOG2, Ops.EXP2, Ops.CDIV, Ops.POW, Ops.FLOORDIV}
All = set(Ops)

View file

@ -86,6 +86,10 @@ propagate_invalid = PatternMatcher([
lambda x,i: x.src[1] if len(x.src) > 1 else x.const_like(0)),
])
pm_remove_invalid = PatternMatcher([
(UPat(Ops.CONST, arg=Invalid, name="i"), lambda i: i.const_like(0) if i.dtype.scalar() is not dtypes.weakint else None),
])
symbolic_simple = propagate_invalid + PatternMatcher([
# ** self folding **
(UPat.var("x") + 0, lambda x: x), # x+0 -> x